Example #1
    def test_hypergraph_mask_gen(self):
        molecules = True
        grammar_cache = 'tmp.pickle'
        grammar = 'hypergraph:' + grammar_cache
        # create a grammar cache inferred from our sample molecules
        g = HypergraphGrammar(cache_file=grammar_cache)
        if os.path.isfile(g.cache_file):
            os.remove(g.cache_file)
        g.strings_to_actions(get_zinc_smiles(5))
        mask_gen1 = get_codec(molecules, grammar, 30).mask_gen
        mask_gen2 = get_codec(molecules, grammar, 30).mask_gen
        mask_gen1.priors = False
        mask_gen2.priors = True
        policy1 = SoftmaxRandomSamplePolicy(
            bias=mask_gen1.grammar.get_log_frequencies())
        policy2 = SoftmaxRandomSamplePolicy()
        lp = []
        for mg in [mask_gen1, mask_gen2]:
            mg.reset()
            mg.apply_action([None])
            logit_priors = mg.action_prior_logits()  # that includes any priors
            lp.append(
                torch.from_numpy(logit_priors).to(device=device,
                                                  dtype=torch.float32))

        dummy_model_output = torch.ones_like(lp[0])
        eff_logits = []
        for this_lp, policy in zip(lp, [policy1, policy2]):
            # combine the mask-gen priors with the model output before applying the policy bias
            eff_logits.append(
                policy.effective_logits(this_lp + dummy_model_output))

        assert torch.max((eff_logits[0] - eff_logits[1]).abs()) < 1e-6
 def test_classic_mask_gen_equations(self):
     molecules = False
     grammar = 'classic'
     codec = get_codec(molecules, grammar, max_seq_length)
     actions = run_random_gen(codec.mask_gen)
     all_eqs = codec.actions_to_strings(actions)
     # the only way of testing correctness we have here is whether the equations parse correctly
     parsed_eqs = codec.strings_to_actions(all_eqs)
 def test_classic_mask_gen_molecules(self):
     molecules = True
     grammar = 'classic'
     codec = get_codec(molecules, grammar, max_seq_length)
     actions = run_random_gen(codec.mask_gen)
     new_smiles = codec.actions_to_strings(actions)
     # the SMILES produced by this grammar are NOT guaranteed to be valid,
     # so we can only check that decoding completes without errors and yields grammatically valid strings
     parsed_smiles = codec.strings_to_actions(new_smiles)
def get_node_decoder(grammar,
                     max_seq_length=15,
                     drop_rate=0.0,
                     decoder_type='attn',
                     rule_policy=None,
                     reward_fun=lambda x: -1 * np.ones(len(x)),
                     batch_size=None,
                     priors='conditional',
                     bins=10):

    codec = get_codec(True, grammar, max_seq_length)
    assert 'hypergraph' in grammar, "Only the hypergraph grammar can be used with attn_graph decoder type"
    if 'attn' in decoder_type:
        model_type = 'transformer'
    elif 'rnn' in decoder_type:
        model_type = 'rnn'
    else:
        raise NotImplementedError('Unknown decoder type: ' + str(decoder_type))

    if 'distr' in decoder_type:
        if 'softmax' in decoder_type:
            output_type = 'distributions_softmax'
        else:
            output_type = 'distributions_thompson'
    else:
        output_type = 'values'

    model = get_graph_model(codec,
                            drop_rate,
                            model_type,
                            output_type,
                            num_bins=bins)
    # encoder = GraphEncoder(grammar=codec.grammar,
    #                        d_model=512,
    #                        drop_rate=drop_rate,
    #                        model_type=model_type)
    #
    #
    # model = MultipleOutputHead(model=encoder,
    #                            output_spec={'node': 1,  # to be used to select next node to expand
    #                                         'action': codec.feature_len()},  # to select the action for chosen node
    #                            drop_rate=drop_rate)

    # don't support using this model in VAE-style models yet

    mask_gen = HypergraphMaskGenerator(max_len=max_seq_length,
                                       grammar=codec.grammar)
    mask_gen.priors = priors
    if rule_policy is None:
        rule_policy = SoftmaxRandomSamplePolicy()

    stepper = GraphDecoderWithNodeSelection(model, rule_policy=rule_policy)
    env = GraphEnvironment(mask_gen,
                           reward_fun=reward_fun,
                           batch_size=batch_size)
    decoder = DecoderWithEnvironmentNew(stepper, env)

    return decoder, stepper
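
# A minimal usage sketch for get_node_decoder above (not part of the original example): it
# assumes the 'tmp.pickle' hypergraph grammar cache from Example #1 exists and that the imports
# used by the surrounding examples (numpy as np, etc.) are in scope; the reward function is a
# placeholder.
example_decoder, example_stepper = get_node_decoder(
    'hypergraph:tmp.pickle',
    max_seq_length=30,
    decoder_type='attn',
    reward_fun=lambda x: np.zeros(len(x)),
    batch_size=4,
    priors='conditional')
example_out = example_decoder()  # generates one batch through the wrapped GraphEnvironment
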
def get_node_decoder(grammar,
                     max_seq_length=15,
                     drop_rate=0.0,
                     decoder_type='attn',
                     rule_policy=None,
                     reward_fun=lambda x: -1 * np.ones(len(x)),
                     batch_size=None,
                     priors='conditional',
                     bins=10):

    codec = get_codec(True, grammar, max_seq_length)
    assert 'hypergraph' in grammar, "Only the hypergraph grammar can be used with attn_graph decoder type"
    if 'attn' in decoder_type:
        model_type = 'transformer'
    elif 'rnn' in decoder_type:
        model_type = 'rnn'
    elif 'conditional' in decoder_type:
        if 'sparse' in decoder_type:
            model_type = 'conditional_sparse'
        else:
            model_type = 'conditional'
    else:
        raise NotImplementedError('Unknown decoder type: ' + str(decoder_type))

    if 'distr' in decoder_type:
        if 'softmax' in decoder_type:
            output_type = 'distributions_softmax'
        else:
            output_type = 'distributions_thompson'
    else:
        output_type = 'values'

    model = get_graph_model(codec,
                            drop_rate,
                            model_type,
                            output_type,
                            num_bins=bins)

    if model_type == 'conditional_sparse':
        priors = 'term_dist_only'

    mask_gen = HypergraphMaskGenerator(max_len=max_seq_length,
                                       grammar=codec.grammar,
                                       priors=priors)
    mask_gen.priors = priors
    if rule_policy is None:
        rule_policy = SoftmaxRandomSamplePolicySparse(
        ) if 'sparse' in decoder_type else SoftmaxRandomSamplePolicy()

    stepper_type = GraphDecoderWithNodeSelectionSparse if 'sparse' in decoder_type else GraphDecoderWithNodeSelection
    stepper = stepper_type(model, rule_policy=rule_policy)
    env = GraphEnvironment(mask_gen,
                           reward_fun=reward_fun,
                           batch_size=batch_size)
    decoder = DecoderWithEnvironmentNew(stepper, env)

    return decoder, stepper
 def test_hypergraph_mask_gen(self):
     molecules = True
     grammar_cache = 'tmp.pickle'
     grammar = 'hypergraph:' + grammar_cache
     # create a grammar cache inferred from our sample molecules
     g = HypergraphGrammar(cache_file=grammar_cache)
     if os.path.isfile(g.cache_file):
         os.remove(g.cache_file)
     g.strings_to_actions(smiles)
     codec = get_codec(molecules, grammar, max_seq_length)
     self.generate_and_validate(codec)
 def test_graph_encoder_with_head(self):
     codec = get_codec(molecules=True,
                       grammar='hypergraph:' + tmp_file,
                       max_seq_length=max_seq_length)
     encoder = GraphEncoder(grammar=gi.grammar,
                            d_model=512,
                            drop_rate=0.0)
     mol_graphs = [HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)]
     model = MultipleOutputHead(model=encoder,
                                output_spec={'node': 1,  # to be used to select next node to expand
                                             'action': codec.feature_len()},  # to select the action for chosen node
                                drop_rate=0.1).to(device)
     out = model(mol_graphs)
def get_vae(molecules=True,
            grammar=True,
            weights_file=None,
            epsilon_std=1,
            decoder_type='step',
            **kwargs):
    model_args = get_model_args(molecules=molecules, grammar=grammar)
    for key, value in kwargs.items():
        if key in model_args:
            model_args[key] = value
    sample_z = model_args.pop('sample_z')

    encoder_args = [
        'feature_len', 'max_seq_length', 'cnn_encoder_params', 'drop_rate',
        'encoder_type', 'rnn_encoder_hidden_n'
    ]
    encoder = get_encoder(**{
        key: value
        for key, value in model_args.items() if key in encoder_args
    })

    decoder_args = [
        'z_size', 'decoder_hidden_n', 'feature_len', 'max_seq_length',
        'drop_rate', 'batch_size'
    ]
    decoder, _ = get_decoder(molecules,
                             grammar,
                             decoder_type=decoder_type,
                             **{
                                 key: value
                                 for key, value in model_args.items()
                                 if key in decoder_args
                             })

    model = generative_playground.models.heads.vae.VariationalAutoEncoderHead(
        encoder=encoder,
        decoder=decoder,
        sample_z=sample_z,
        epsilon_std=epsilon_std,
        z_size=model_args['z_size'])

    if weights_file is not None:
        model.load(weights_file)

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec = get_codec(molecules,
                      grammar,
                      max_seq_length=settings['max_seq_length'])
    # codec.set_model(model)  # todo do we ever use this?
    return model, codec
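
# Hypothetical usage sketch for get_vae above (not in the original): builds a classic-grammar
# molecule VAE; drop_rate is forwarded into model_args through **kwargs.
vae_model, vae_codec = get_vae(molecules=True,
                               grammar=True,
                               decoder_type='action',
                               drop_rate=0.2)
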
 def generic_decoder_test(self, decoder_type, grammar):
     codec = get_codec(molecules=True,
                       grammar=grammar,
                       max_seq_length=max_seq_length)
     decoder, pre_decoder = get_decoder(decoder_type=decoder_type,
                                        max_seq_length=max_seq_length,
                                        grammar=grammar,
                                        feature_len=codec.feature_len(),
                                        z_size=z_size,
                                        batch_size=batch_size)
     out = decoder()
     # it returns all sorts of things: out_actions_all, out_logits_all, out_rewards_all, out_terminals_all, (info[0], to_pytorch(info[1]))
     all_sum = torch.sum(out['logits'])
     all_sum.backward()
     return all_sum
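 # Hypothetical call sketch (inside the same test class); the decoder type string and the
 # grammar cache file are illustrative:
 #     self.generic_decoder_test('graph_conditional', 'hypergraph:' + tmp_file)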
Example #10
def get_thompson_globals(
        num_bins=50,  # TODO: replace with a Value Distribution object
        reward_fun_=None,
        grammar_cache='hyper_grammar_guac_10k_with_clique_collapse.pickle',  # 'hyper_grammar.pickle'
        max_seq_length=60,
        decay=0.95,
        updates_to_refresh=10):
    grammar_name = 'hypergraph:' + grammar_cache
    codec = get_codec(True, grammar_name, max_seq_length)
    reward_proc = RewardProcessor(num_bins)

    rule_choice_repo_factory = lambda x: RuleChoiceRepository(
        reward_proc=reward_proc, mask=x, decay=decay)

    exp_repo_ = ExperienceRepository(
        grammar=codec.grammar,
        reward_preprocessor=reward_proc,
        decay=decay,
        conditional_keys=[
            key for key in codec.grammar.conditional_frequencies.keys()
        ],
        rule_choice_repo_factory=rule_choice_repo_factory)

    # TODO: weave this into the nodes to do node-level action averages as regularization
    local_exp_repo_factory = lambda graph: ExperienceRepository(
        grammar=codec.grammar,
        reward_preprocessor=reward_proc,
        decay=decay,
        conditional_keys=[i for i in range(len(graph))],
        rule_choice_repo_factory=rule_choice_repo_factory)

    globals = GlobalParametersThompson(
        codec.grammar,
        max_seq_length,
        exp_repo_,
        decay=decay,
        updates_to_refresh=updates_to_refresh,
        reward_fun=reward_fun_,
        reward_proc=reward_proc,
        rule_choice_repo_factory=rule_choice_repo_factory,
        state_store=None)

    return globals
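
# Hypothetical usage sketch for get_thompson_globals above; the reward function is a placeholder
# and the default grammar cache file is assumed to be present.
thompson_globals = get_thompson_globals(
    num_bins=50,
    reward_fun_=lambda smiles: np.zeros(len(smiles)),
    max_seq_length=60)
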
Example #11
    def __init__(
            self,
            batch_size=20,
            reward_fun_=None,
            grammar_cache='hyper_grammar_guac_10k_with_clique_collapse.pickle',  # 'hyper_grammar.pickle'
            max_depth=60,
            lr=0.05,
            grad_clip=5,
            entropy_weight=3,
            decay=None,
            num_bins=None,
            updates_to_refresh=None,
            plotter=None,
            degenerate=False  # use a null model if true
    ):
        grammar_name = 'hypergraph:' + grammar_cache
        codec = get_codec(True, grammar_name, max_depth)
        super().__init__(codec.grammar,
                         max_depth,
                         reward_fun_, {},
                         plotter=plotter)

        if not degenerate:
            # create optimizer factory
            optimizer_factory = optimizer_factory_gen(lr, grad_clip)
            # create model
            model = CondtionalProbabilityModel(codec.grammar).to(device)
            # create loss object
            loss_type = 'advantage_record'
            loss_fun = PolicyGradientLoss(loss_type,
                                          entropy_wgt=entropy_weight)
            self.model = model
            self.process_reward = MCTSRewardProcessor(loss_fun, model,
                                                      optimizer_factory,
                                                      batch_size)
        else:

            self.model = PassthroughModel()
            self.process_reward = lambda reward, log_ps, actions, params: None

        self.decay = decay
        self.reward_proc = RewardProcessor(num_bins)
        self.updates_to_refresh = updates_to_refresh
def get_model_args(molecules,
                   grammar,
                   drop_rate=0.5,
                   sample_z=False,
                   encoder_type='rnn'):
    settings = get_settings(molecules, grammar)
    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    model_args = {
        'z_size': settings['z_size'],
        'decoder_hidden_n': settings['decoder_hidden_n'],
        'feature_len': codec.feature_len(),
        'max_seq_length': settings['max_seq_length'],
        'cnn_encoder_params': settings['cnn_encoder_params'],
        'drop_rate': drop_rate,
        'sample_z': sample_z,
        'encoder_type': encoder_type,
        'rnn_encoder_hidden_n': settings['rnn_encoder_hidden_n']
    }

    return model_args
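
# Hypothetical usage sketch (not in the original): the resulting dict is what get_vae above
# filters into encoder and decoder keyword arguments.
args = get_model_args(molecules=True, grammar=True)
print(sorted(args.keys()))
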
Example #13
 def __init__(self,
              molecules=True,
              grammar=True,
              reward_fun=None,
              batch_size=1,
              max_steps=None,
              save_dataset=None):
     settings = get_settings(molecules, grammar)
     self.codec = get_codec(molecules, grammar, settings['max_seq_length'])
     self.action_dim = self.codec.feature_len()
     self.state_dim = self.action_dim
     if max_steps is None:
         self._max_episode_steps = settings['max_seq_length']
     else:
         self._max_episode_steps = max_steps
     self.reward_fun = reward_fun
     self.batch_size = batch_size
     self.save_dataset = save_dataset
     self.smiles = None
     self.seq_len = None
     self.valid = None
     self.actions = None
     self.done_rewards = None
     self.reset()
def train_policy_gradient(molecules=True,
                          grammar=True,
                          smiles_source='ZINC',
                          EPOCHS=None,
                          BATCH_SIZE=None,
                          reward_fun_on=None,
                          reward_fun_off=None,
                          max_steps=277,
                          lr_on=2e-4,
                          lr_discrim=1e-4,
                          lr_schedule=None,
                          discrim_wt=2,
                          p_thresh=0.5,
                          drop_rate=0.0,
                          plot_ignore_initial=0,
                          randomize_reward=False,
                          save_file_root_name=None,
                          reward_sm=0.0,
                          preload_file_root_name=None,
                          anchor_file=None,
                          anchor_weight=0.0,
                          decoder_type='action',
                          plot_prefix='',
                          dashboard='policy gradient',
                          smiles_save_file=None,
                          on_policy_loss_type='best',
                          priors=True,
                          node_temperature_schedule=lambda x: 1.0,
                          rule_temperature_schedule=lambda x: 1.0,
                          eps=0.0,
                          half_float=False,
                          extra_repetition_penalty=0.0,
                          entropy_wgt=1.0):
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../../'

    def full_path(x):
        return os.path.realpath(root_location + 'pretrained/' + x)

    zinc_data = get_smiles_from_database(source=smiles_source)
    zinc_set = set(zinc_data)
    lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

    if save_file_root_name is not None:
        gen_save_file = save_file_root_name + '_gen.h5'
        disc_save_file = save_file_root_name + '_disc.h5'
    if preload_file_root_name is not None:
        gen_preload_file = preload_file_root_name + '_gen.h5'
        disc_preload_file = preload_file_root_name + '_disc.h5'

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    discrim_model = GraphDiscriminator(codec.grammar, drop_rate=drop_rate)
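    # note: the 'False and' guard below deliberately short-circuits, so discriminator weights
    # are never preloaded in this example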
    if False and preload_file_root_name is not None:
        try:
            preload_path = full_path(disc_preload_file)
            discrim_model.load_state_dict(torch.load(preload_path),
                                          strict=False)
            print('Discriminator weights loaded successfully!')
        except Exception as e:
            print('failed to load discriminator weights ' + str(e))

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    alt_reward_calc = AdjustedRewardCalculator(reward_fun_on,
                                               zinc_set,
                                               lookbacks,
                                               extra_repetition_penalty,
                                               discrim_wt,
                                               discrim_model=None)

    reward_fun = lambda x: adj_reward(discrim_wt,
                                      discrim_model,
                                      reward_fun_on,
                                      zinc_set,
                                      history_data,
                                      extra_repetition_penalty,
                                      x,
                                      alt_calc=alt_reward_calc)

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=reward_fun,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=None)

    node_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0),
                                            eps=eps)
    rule_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(2.0),
                                            eps=eps)

    model = get_decoder(molecules,
                        grammar,
                        z_size=settings['z_size'],
                        decoder_hidden_n=200,
                        feature_len=codec.feature_len(),
                        max_seq_length=max_steps,
                        drop_rate=drop_rate,
                        batch_size=BATCH_SIZE,
                        decoder_type=decoder_type,
                        reward_fun=reward_fun,
                        task=task,
                        node_policy=node_policy,
                        rule_policy=rule_policy,
                        priors=priors)[0]

    if preload_file_root_name is not None:
        try:
            preload_path = full_path(gen_preload_file)
            model.load_state_dict(torch.load(preload_path, map_location='cpu'),
                                  strict=False)
            print('Generator weights loaded successfully!')
        except Exception as e:
            print('failed to load generator weights ' + str(e))

    anchor_model = None

    # construct the loader to feed the discriminator
    def make_callback(data):
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            for s in smiles:  # only store unique instances of molecules so discriminator can't guess on frequency
                if s not in data:
                    data.append(s)

        return hc

    class TemperatureCallback:
        def __init__(self, policy, temperature_function):
            self.policy = policy
            self.counter = 0
            self.temp_fun = temperature_function

        def __call__(self, inputs, model, outputs, loss_fn, loss):
            self.counter += 1
            target_temp = self.temp_fun(self.counter)
            self.policy.set_temperature(target_temp)

    # need to have something there to begin with, else the DataLoader constructor barfs

    def get_rl_fitter(model,
                      loss_obj,
                      train_gen,
                      save_path,
                      fit_plot_prefix='',
                      model_process_fun=None,
                      lr=None,
                      lr_schedule=lr_schedule,
                      extra_callbacks=[],
                      loss_display_cap=float('inf'),
                      anchor_model=None,
                      anchor_weight=0):
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr, eps=1e-4)
        if lr_schedule is None:
            lr_schedule = lambda x: 1.0
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun,
                smooth_weight=reward_sm,
                save_location=os.path.dirname(save_path))
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=10,
                                    save_path=save_path,
                                    save_always=True)

        fitter = fit_rl(train_gen=train_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        half_float=half_float,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        callbacks=[metric_monitor, checkpointer] +
                        extra_callbacks)

        return fitter

    class GeneratorToIterable:
        def __init__(self, gen):
            self.gen = gen
            # we assume the generator is finite
            self.len = 0
            for _ in gen():
                self.len += 1

        def __len__(self):
            return self.len

        def __iter__(self):
            return self.gen()

    def my_gen():
        for _ in range(1000):
            yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

    # the on-policy fitter

    gen_extra_callbacks = [make_callback(d) for d in history_data]

    if smiles_save_file is not None:
        smiles_save_path = os.path.realpath(root_location + 'pretrained/' +
                                            smiles_save_file)
        gen_extra_callbacks.append(MoleculeSaver(smiles_save_path, gzip=True))
        print('Saving SMILES to {}'.format(smiles_save_file))

    if node_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(node_policy, node_temperature_schedule))

    if rule_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(rule_policy, rule_temperature_schedule))

    fitter1 = get_rl_fitter(
        model,
        PolicyGradientLoss(
            on_policy_loss_type,
            entropy_wgt=entropy_wgt),  # last_reward_wgt=reward_sm),
        GeneratorToIterable(my_gen),
        full_path(gen_save_file),
        plot_prefix + 'on-policy',
        model_process_fun=model_process_fun,
        lr=lr_on,
        extra_callbacks=gen_extra_callbacks,
        anchor_model=anchor_model,
        anchor_weight=anchor_weight)
    #
    # # get existing molecule data to add training
    pre_dataset = EvenlyBlendedDataset(
        2 * [history_data[0]] + history_data[1:],
        labels=False)  # a blend of 3 time horizons
    dataset = EvenlyBlendedDataset([pre_dataset, zinc_data], labels=True)
    discrim_loader = DataLoader(dataset, shuffle=True, batch_size=50)

    class MyLoss(nn.Module):
        def __init__(self):
            super().__init__()
            self.celoss = nn.CrossEntropyLoss()

        def forward(self, x):
            # tmp = discriminator_reward_mult(x['smiles'])
            # tmp2 = F.softmax(x['p_zinc'], dim=1)[:,1].detach().cpu().numpy()
            # import numpy as np
            # assert np.max(np.abs(tmp-tmp2)) < 1e-6
            return self.celoss(x['p_zinc'].to(device),
                               x['dataset_index'].to(device))

    fitter2 = get_rl_fitter(discrim_model,
                            MyLoss(),
                            IterableTransform(
                                discrim_loader, lambda x: {
                                    'smiles': x['X'],
                                    'dataset_index': x['dataset_index']
                                }),
                            full_path(disc_save_file),
                            plot_prefix + ' discriminator',
                            lr=lr_discrim,
                            model_process_fun=None)

    def on_policy_gen(fitter, model):
        while True:
            # model.policy = SoftmaxRandomSamplePolicy()#bias=codec.grammar.get_log_frequencies())
            yield next(fitter)

    return model, fitter1, fitter2  #,on_policy_gen(fitter1, model)
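
# Hypothetical invocation sketch (not in the original example): the reward function, epoch and
# batch counts, grammar cache and file root are placeholders; each fitter is advanced with
# next(...), as in on_policy_gen above.
pg_model, pg_fitter1, pg_fitter2 = train_policy_gradient(
    molecules=True,
    grammar='hypergraph:tmp.pickle',
    EPOCHS=2,
    BATCH_SIZE=20,
    reward_fun_on=lambda smiles: np.zeros(len(smiles)),
    save_file_root_name='test_pg',
    decoder_type='graph_conditional')
next(pg_fitter1)  # one on-policy generator step
next(pg_fitter2)  # one discriminator step
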
    def __init__(self,
                 grammar,
                 smiles_source='ZINC',
                 BATCH_SIZE=None,
                 reward_fun=None,
                 max_steps=277,
                 num_batches=100,
                 lr=2e-4,
                 entropy_wgt=1.0,
                 lr_schedule=None,
                 root_name=None,
                 preload_file_root_name=None,
                 save_location=None,
                 plot_metrics=True,
                 metric_smooth=0.0,
                 decoder_type='graph_conditional',
                 on_policy_loss_type='advantage_record',
                 priors='conditional',
                 rule_temperature_schedule=None,
                 eps=0.0,
                 half_float=False,
                 extra_repetition_penalty=0.0):

        self.num_batches = num_batches
        self.save_location = save_location
        self.molecule_saver = MoleculeSaver(None, gzip=True)
        self.metric_monitor = None  # to be populated by self.set_root_name(...)

        zinc_data = get_smiles_from_database(source=smiles_source)
        zinc_set = set(zinc_data)
        lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
        history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

        if root_name is not None:
            pass
            # gen_save_file = root_name + '_gen.h5'
        if preload_file_root_name is not None:
            gen_preload_file = preload_file_root_name + '_gen.h5'

        settings = get_settings(molecules=True, grammar=grammar)
        codec = get_codec(True, grammar, settings['max_seq_length'])

        if BATCH_SIZE is not None:
            settings['BATCH_SIZE'] = BATCH_SIZE

        self.alt_reward_calc = AdjustedRewardCalculator(
            reward_fun,
            zinc_set,
            lookbacks,
            extra_repetition_penalty,
            0,
            discrim_model=None)
        self.reward_fun = lambda x: adj_reward(0,
                                               None,
                                               reward_fun,
                                               zinc_set,
                                               history_data,
                                               extra_repetition_penalty,
                                               x,
                                               alt_calc=self.alt_reward_calc)

        task = SequenceGenerationTask(molecules=True,
                                      grammar=grammar,
                                      reward_fun=self.alt_reward_calc,
                                      batch_size=BATCH_SIZE,
                                      max_steps=max_steps,
                                      save_dataset=None)

        if 'sparse' in decoder_type:
            rule_policy = SoftmaxRandomSamplePolicySparse()
        else:
            rule_policy = SoftmaxRandomSamplePolicy(
                temperature=torch.tensor(1.0), eps=eps)

        # TODO: strip this down to the normal call
        self.model = get_decoder(True,
                                 grammar,
                                 z_size=settings['z_size'],
                                 decoder_hidden_n=200,
                                 feature_len=codec.feature_len(),
                                 max_seq_length=max_steps,
                                 batch_size=BATCH_SIZE,
                                 decoder_type=decoder_type,
                                 reward_fun=self.alt_reward_calc,
                                 task=task,
                                 rule_policy=rule_policy,
                                 priors=priors)[0]

        if preload_file_root_name is not None:
            try:
                preload_path = os.path.realpath(save_location +
                                                gen_preload_file)
                self.model.load_state_dict(torch.load(preload_path,
                                                      map_location='cpu'),
                                           strict=False)
                print('Generator weights loaded successfully!')
            except Exception as e:
                print('failed to load generator weights ' + str(e))

        # construct the loader to feed the discriminator
        def make_callback(data):
            def hc(inputs, model, outputs, loss_fn, loss):
                graphs = outputs['graphs']
                smiles = [g.to_smiles() for g in graphs]
                for s in smiles:  # only store unique instances of molecules so discriminator can't guess on frequency
                    if s not in data:
                        data.append(s)

            return hc

        if plot_metrics:
            # TODO: save_file for rewards data goes here?
            self.metric_monitor_factory = lambda name: MetricPlotter(
                plot_prefix='',
                loss_display_cap=float('inf'),
                dashboard_name=name,
                save_location=save_location,
                process_model_fun=model_process_fun,
                smooth_weight=metric_smooth)
        else:
            self.metric_monitor_factory = lambda x: None

        # the on-policy fitter

        gen_extra_callbacks = [make_callback(d) for d in history_data]
        gen_extra_callbacks.append(self.molecule_saver)
        if rule_temperature_schedule is not None:
            gen_extra_callbacks.append(
                TemperatureCallback(rule_policy, rule_temperature_schedule))

        nice_params = filter(lambda p: p.requires_grad,
                             self.model.parameters())
        self.optimizer = optim.Adam(nice_params, lr=lr, eps=1e-4)

        if lr_schedule is None:
            lr_schedule = lambda x: 1.0
        self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_schedule)
        self.loss = PolicyGradientLoss(on_policy_loss_type,
                                       entropy_wgt=entropy_wgt)
        self.fitter_factory = lambda: make_fitter(BATCH_SIZE, settings[
            'z_size'], [self.metric_monitor] + gen_extra_callbacks, self)

        self.fitter = self.fitter_factory()
        self.set_root_name(root_name)
        print('Runner initialized!')
def get_decoder(
        molecules=True,
        grammar=True,
        z_size=200,
        decoder_hidden_n=200,
        feature_len=12,  # TODO: remove this
        max_seq_length=15,
        drop_rate=0.0,
        decoder_type='step',
        task=None,
        batch_size=None):

    codec = get_codec(molecules, grammar, max_seq_length)

    if decoder_type == 'old':
        stepper = ResettingRNNDecoder(z_size=z_size,
                                      hidden_n=decoder_hidden_n,
                                      feature_len=codec.feature_len(),
                                      max_seq_length=max_seq_length,
                                      steps=max_seq_length,
                                      drop_rate=drop_rate)
        stepper = OneStepDecoderContinuous(stepper)
    elif decoder_type == 'attn_graph':
        assert 'hypergraph' in grammar, "Only the hypergraph grammar can be used with attn_graph decoder type"
        encoder = GraphEncoder(grammar=codec.grammar,
                               d_model=512,
                               drop_rate=drop_rate)

        model = MultipleOutputHead(
            model=encoder,
            output_spec={
                'node': 1,  # to be used to select next node to expand
                'action': codec.feature_len()
            },  # to select the action for chosen node
            drop_rate=drop_rate)

        # don't support using this model in VAE-style scenarios yet
        model.init_encoder_output = lambda x: None

        mask_gen = HypergraphMaskGenerator(max_len=max_seq_length,
                                           grammar=codec.grammar)

        stepper = GraphDecoder(model=model, mask_gen=mask_gen)

    else:
        if decoder_type == 'step':
            stepper = SimpleRNNDecoder(z_size=z_size,
                                       hidden_n=decoder_hidden_n,
                                       feature_len=codec.feature_len(),
                                       max_seq_length=max_seq_length,
                                       drop_rate=drop_rate,
                                       use_last_action=False)

        elif decoder_type == 'action':
            stepper = SimpleRNNDecoder(
                z_size=z_size,  # + feature_len,
                hidden_n=decoder_hidden_n,
                feature_len=codec.feature_len(),
                max_seq_length=max_seq_length,
                drop_rate=drop_rate,
                use_last_action=True)

        elif decoder_type == 'action_resnet':
            stepper = ResNetRNNDecoder(
                z_size=z_size,  # + feature_len,
                hidden_n=decoder_hidden_n,
                feature_len=codec.feature_len(),
                max_seq_length=max_seq_length,
                drop_rate=drop_rate,
                use_last_action=True)

        elif decoder_type == 'attention':
            stepper = SelfAttentionDecoderStep(num_actions=codec.feature_len(),
                                               max_seq_len=max_seq_length,
                                               drop_rate=drop_rate,
                                               enc_output_size=z_size)
        elif decoder_type == 'random':
            stepper = RandomDecoder(feature_len=codec.feature_len(),
                                    max_seq_length=max_seq_length)
        else:
            raise NotImplementedError('Unknown decoder type: ' +
                                      str(decoder_type))

    if grammar is not False:
        # add a masking layer
        mask_gen = get_codec(molecules, grammar, max_seq_length).mask_gen
        stepper = MaskingHead(stepper, mask_gen)

    policy = SoftmaxRandomSamplePolicy(
    )  #bias=codec.grammar.get_log_frequencies())

    decoder = to_gpu(
        SimpleDiscreteDecoderWithEnv(
            stepper, policy, task=task,
            batch_size=batch_size))  # , bypass_actions=True))

    return decoder, stepper
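
# Hypothetical usage sketch for get_decoder above (illustrative values; no task or reward wired in):
rnn_decoder, rnn_stepper = get_decoder(molecules=True,
                                       grammar=True,
                                       z_size=200,
                                       decoder_hidden_n=200,
                                       max_seq_length=15,
                                       decoder_type='action',
                                       batch_size=8)
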
Example #17
def train_policy_gradient_ppo(molecules=True,
                              grammar=True,
                              smiles_source='ZINC',
                              EPOCHS=None,
                              BATCH_SIZE=None,
                              reward_fun_on=None,
                              reward_fun_off=None,
                              max_steps=277,
                              lr_on=2e-4,
                              lr_discrim=1e-4,
                              discrim_wt=2,
                              p_thresh=0.5,
                              drop_rate=0.0,
                              plot_ignore_initial=0,
                              randomize_reward=False,
                              save_file_root_name=None,
                              reward_sm=0.0,
                              preload_file_root_name=None,
                              anchor_file=None,
                              anchor_weight=0.0,
                              decoder_type='action',
                              plot_prefix='',
                              dashboard='policy gradient',
                              smiles_save_file=None,
                              on_policy_loss_type='best',
                              priors=True,
                              node_temperature_schedule=lambda x: 1.0,
                              rule_temperature_schedule=lambda x: 1.0,
                              eps=0.0,
                              half_float=False,
                              extra_repetition_penalty=0.0):
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../../'

    def full_path(x):
        return os.path.realpath(root_location + 'pretrained/' + x)

    if save_file_root_name is not None:
        gen_save_file = save_file_root_name + '_gen.h5'
        disc_save_file = save_file_root_name + '_disc.h5'
    if preload_file_root_name is not None:
        gen_preload_file = preload_file_root_name + '_gen.h5'
        disc_preload_file = preload_file_root_name + '_disc.h5'

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    discrim_model = GraphDiscriminator(codec.grammar, drop_rate=drop_rate)
    if False and preload_file_root_name is not None:
        try:
            preload_path = full_path(disc_preload_file)
            discrim_model.load_state_dict(torch.load(preload_path),
                                          strict=False)
            print('Discriminator weights loaded successfully!')
        except Exception as e:
            print('failed to load discriminator weights ' + str(e))

    zinc_data = get_smiles_from_database(source=smiles_source)
    zinc_set = set(zinc_data)
    lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

    def originality_mult(smiles_list):
        out = []
        for s in smiles_list:
            if s in zinc_set:
                out.append(0.5)
            elif s in history_data[0]:
                out.append(0.5)
            elif s in history_data[1]:
                out.append(0.70)
            elif s in history_data[2]:
                out.append(0.85)
            else:
                out.append(1.0)
        return np.array(out)

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def discriminator_reward_mult(smiles_list):
        orig_state = discrim_model.training
        discrim_model.eval()
        discrim_out_logits = discrim_model(smiles_list)['p_zinc']
        discrim_probs = F.softmax(discrim_out_logits, dim=1)
        prob_zinc = discrim_probs[:, 1].detach().cpu().numpy()
        if orig_state:
            discrim_model.train()
        return prob_zinc

    def apply_originality_penalty(x, orig_mult):
        assert x <= 1, "Reward must be no greater than 1"
        if x > 0.5:  # want to punish nearly-perfect scores less and less
            out = math.pow(x, 1 / orig_mult)
        else:  # continuous join at 0.5
            penalty = math.pow(0.5, 1 / orig_mult) - 0.5
            out = x + penalty

        out -= extra_repetition_penalty * (1 - 1 / orig_mult)
        return out
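    # Worked example (illustrative, not in the original): with orig_mult = 0.5, i.e. a molecule
    # already present in the ZINC set, 1 / orig_mult = 2, so a near-perfect raw reward x = 0.9
    # is damped to 0.9 ** 2 = 0.81, while x = 0.3 takes the second branch and becomes
    # 0.3 + (0.5 ** 2 - 0.5) = 0.05, before the extra_repetition_penalty adjustment.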

    def adj_reward(x):
        if discrim_wt > 1e-5:
            p = discriminator_reward_mult(x)
        else:
            p = 0
        rwd = np.array(reward_fun_on(x))
        orig_mult = originality_mult(x)
        # we assume the reward is <=1, first term will dominate for reward <0, second for 0 < reward < 1
        # reward = np.minimum(rwd/orig_mult, np.power(np.abs(rwd),1/orig_mult))
        reward = np.array([
            apply_originality_penalty(x, om) for x, om in zip(rwd, orig_mult)
        ])
        out = reward + discrim_wt * p * orig_mult
        return out

    def adj_reward_old(x):
        p = discriminator_reward_mult(x)
        w = sigmoid(-(p - p_thresh) / 0.01)
        if randomize_reward:
            rand = np.random.uniform(size=p.shape)
            w *= rand
        reward = np.maximum(reward_fun_on(x), p_thresh)
        weighted_reward = w * p + (1 - w) * reward
        out = weighted_reward * originality_mult(x)  #
        return out

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=adj_reward,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=None)

    node_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0),
                                            eps=eps)
    rule_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0),
                                            eps=eps)

    model = get_decoder(molecules,
                        grammar,
                        z_size=settings['z_size'],
                        decoder_hidden_n=200,
                        feature_len=codec.feature_len(),
                        max_seq_length=max_steps,
                        drop_rate=drop_rate,
                        batch_size=BATCH_SIZE,
                        decoder_type=decoder_type,
                        reward_fun=adj_reward,
                        task=task,
                        node_policy=node_policy,
                        rule_policy=rule_policy,
                        priors=priors)[0]

    if preload_file_root_name is not None:
        try:
            preload_path = full_path(gen_preload_file)
            model.load_state_dict(torch.load(preload_path, map_location='cpu'),
                                  strict=False)
            print('Generator weights loaded successfully!')
        except Exception as e:
            print('failed to load generator weights ' + str(e))

    anchor_model = None

    from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer
    import numpy as np
    scorer = NormalizedScorer()

    if reward_fun_off is None:
        reward_fun_off = reward_fun_on

    # construct the loader to feed the discriminator
    def make_callback(data):
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            for s in smiles:  # only store unique instances of molecules so discriminator can't guess on frequency
                if s not in data:
                    data.append(s)

        return hc

    class TemperatureCallback:
        def __init__(self, policy, temperature_function):
            self.policy = policy
            self.counter = 0
            self.temp_fun = temperature_function

        def __call__(self, inputs, model, outputs, loss_fn, loss):
            self.counter += 1
            target_temp = self.temp_fun(self.counter)
            self.policy.set_temperature(target_temp)

    # need to have something there to begin with, else the DataLoader constructor barfs

    def get_rl_fitter(model,
                      loss_obj,
                      train_gen,
                      save_path,
                      fit_plot_prefix='',
                      model_process_fun=None,
                      lr=None,
                      extra_callbacks=[],
                      loss_display_cap=float('inf'),
                      anchor_model=None,
                      anchor_weight=0):
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr, eps=1e-4)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun,
                smooth_weight=reward_sm)
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=True)

        fitter = fit_rl(train_gen=train_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        half_float=half_float,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        callbacks=[metric_monitor, checkpointer] +
                        extra_callbacks)

        return fitter

    def my_gen():
        for _ in range(1000):
            yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

    # the on-policy fitter

    gen_extra_callbacks = [make_callback(d) for d in history_data]

    if smiles_save_file is not None:
        smiles_save_path = os.path.realpath(root_location + 'pretrained/' +
                                            smiles_save_file)
        gen_extra_callbacks.append(MoleculeSaver(smiles_save_path, gzip=True))
        print('Saving SMILES to {}'.format(smiles_save_file))

    if node_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(node_policy, node_temperature_schedule))

    if rule_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(rule_policy, rule_temperature_schedule))

    fitter1 = get_rl_fitter(
        model,
        PolicyGradientLoss(on_policy_loss_type),  # last_reward_wgt=reward_sm),
        GeneratorToIterable(my_gen),
        full_path(gen_save_file),
        plot_prefix + 'on-policy',
        model_process_fun=model_process_fun,
        lr=lr_on,
        extra_callbacks=gen_extra_callbacks,
        anchor_model=anchor_model,
        anchor_weight=anchor_weight)
    #
    # # get existing molecule data to add training
    pre_dataset = EvenlyBlendedDataset(
        2 * [history_data[0]] + history_data[1:],
        labels=False)  # a blend of 3 time horizons
    dataset = EvenlyBlendedDataset([pre_dataset, zinc_data], labels=True)
    discrim_loader = DataLoader(dataset, shuffle=True, batch_size=50)

    class MyLoss(nn.Module):
        def __init__(self):
            super().__init__()
            self.celoss = nn.CrossEntropyLoss()

        def forward(self, x):
            # tmp = discriminator_reward_mult(x['smiles'])
            # tmp2 = F.softmax(x['p_zinc'], dim=1)[:,1].detach().cpu().numpy()
            # import numpy as np
            # assert np.max(np.abs(tmp-tmp2)) < 1e-6
            return self.celoss(x['p_zinc'].to(device),
                               x['dataset_index'].to(device))

    fitter2 = get_rl_fitter(discrim_model,
                            MyLoss(),
                            IterableTransform(
                                discrim_loader, lambda x: {
                                    'smiles': x['X'],
                                    'dataset_index': x['dataset_index']
                                }),
                            full_path(disc_save_file),
                            plot_prefix + ' discriminator',
                            lr=lr_discrim,
                            model_process_fun=None)

    def on_policy_gen(fitter, model):
        while True:
            # model.policy = SoftmaxRandomSamplePolicy()#bias=codec.grammar.get_log_frequencies())
            yield next(fitter)

    return model, fitter1, fitter2  #,on_policy_gen(fitter1, model)
 def test_custom_grammar_mask_gen(self):
     molecules = True
     grammar = 'new'
     codec = get_codec(molecules, grammar, max_seq_length)
     self.generate_and_validate(codec)
Example #19
 def check_codec(self, input, molecules, grammar):
     settings = get_settings(molecules, grammar)
     codec = get_codec(molecules, grammar, settings['max_seq_length'])
     actions = codec.strings_to_actions([input])
     re_input = codec.actions_to_strings(actions)
     self.assertEqual(input, re_input[0])
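 # Hypothetical call sketch: round-trip a known-valid SMILES (paracetamol) through the codec
 #     self.check_codec('CC(=O)Nc1ccc(O)cc1', molecules=True, grammar=True)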
Example #20
def get_decoder(
        molecules=True,
        grammar=True,
        z_size=200,
        decoder_hidden_n=200,
        feature_len=12,  # TODO: remove this
        max_seq_length=15,
        drop_rate=0.0,
        decoder_type='step',
        task=None,
        node_policy=None,
        rule_policy=None,
        reward_fun=lambda x: -1 * np.ones(len(x)),
        batch_size=None,
        priors=True):
    codec = get_codec(molecules, grammar, max_seq_length)

    if decoder_type == 'old':
        stepper = ResettingRNNDecoder(z_size=z_size,
                                      hidden_n=decoder_hidden_n,
                                      feature_len=codec.feature_len(),
                                      max_seq_length=max_seq_length,
                                      steps=max_seq_length,
                                      drop_rate=drop_rate)
        stepper = OneStepDecoderContinuous(stepper)
    elif 'graph' in decoder_type and decoder_type not in [
            'attn_graph', 'rnn_graph'
    ]:
        return get_node_decoder(grammar, max_seq_length, drop_rate,
                                decoder_type, rule_policy, reward_fun,
                                batch_size, priors)

    elif decoder_type in ['attn_graph', 'rnn_graph']:  # deprecated
        assert 'hypergraph' in grammar, "Only the hypergraph grammar can be used with attn_graph decoder type"
        if 'attn' in decoder_type:
            encoder = GraphEncoder(grammar=codec.grammar,
                                   d_model=512,
                                   drop_rate=drop_rate,
                                   model_type='transformer')
        elif 'rnn' in decoder_type:
            encoder = GraphEncoder(grammar=codec.grammar,
                                   d_model=512,
                                   drop_rate=drop_rate,
                                   model_type='rnn')

        model = MultipleOutputHead(
            model=encoder,
            output_spec={
                'node': 1,  # to be used to select next node to expand
                'action': codec.feature_len()
            },  # to select the action for chosen node
            drop_rate=drop_rate)

        # don't support using this model in VAE-style models yet
        model.init_encoder_output = lambda x: None
        mask_gen = HypergraphMaskGenerator(max_len=max_seq_length,
                                           grammar=codec.grammar)
        mask_gen.priors = priors
        # bias=codec.grammar.get_log_frequencies())
        if node_policy is None:
            node_policy = SoftmaxRandomSamplePolicy()
        if rule_policy is None:
            rule_policy = SoftmaxRandomSamplePolicy()
        if 'node' in decoder_type:
            stepper = GraphDecoderWithNodeSelection(model,
                                                    node_policy=node_policy,
                                                    rule_policy=rule_policy)
            env = GraphEnvironment(mask_gen,
                                   reward_fun=reward_fun,
                                   batch_size=batch_size)
            decoder = DecoderWithEnvironmentNew(stepper, env)
        else:

            stepper = GraphDecoder(model=model, mask_gen=mask_gen)
            decoder = to_gpu(
                SimpleDiscreteDecoderWithEnv(stepper,
                                             rule_policy,
                                             task=task,
                                             batch_size=batch_size))
        return decoder, stepper

    else:
        if decoder_type == 'step':
            stepper = SimpleRNNDecoder(z_size=z_size,
                                       hidden_n=decoder_hidden_n,
                                       feature_len=codec.feature_len(),
                                       max_seq_length=max_seq_length,
                                       drop_rate=drop_rate,
                                       use_last_action=False)

        elif decoder_type == 'action':
            stepper = SimpleRNNDecoder(
                z_size=z_size,  # + feature_len,
                hidden_n=decoder_hidden_n,
                feature_len=codec.feature_len(),
                max_seq_length=max_seq_length,
                drop_rate=drop_rate,
                use_last_action=True)

        elif decoder_type == 'action_resnet':
            stepper = ResNetRNNDecoder(
                z_size=z_size,  # + feature_len,
                hidden_n=decoder_hidden_n,
                feature_len=codec.feature_len(),
                max_seq_length=max_seq_length,
                drop_rate=drop_rate,
                use_last_action=True)

        elif decoder_type == 'attention':
            stepper = SelfAttentionDecoderStep(num_actions=codec.feature_len(),
                                               max_seq_len=max_seq_length,
                                               drop_rate=drop_rate,
                                               enc_output_size=z_size)
        elif decoder_type == 'random':
            stepper = RandomDecoder(feature_len=codec.feature_len(),
                                    max_seq_length=max_seq_length)
        else:
            raise NotImplementedError('Unknown decoder type: ' +
                                      str(decoder_type))

    if grammar is not False and '_graph' not in decoder_type:
        # add a masking layer
        mask_gen = get_codec(molecules, grammar, max_seq_length).mask_gen
        stepper = MaskingHead(stepper, mask_gen)

    policy = SoftmaxRandomSamplePolicy(
    )  # bias=codec.grammar.get_log_frequencies())

    decoder = to_gpu(
        SimpleDiscreteDecoderWithEnv(
            stepper, policy, task=task,
            batch_size=batch_size))  # , bypass_actions=True))

    return decoder, stepper
Example #21
def train_policy_gradient(molecules=True,
                          grammar=True,
                          EPOCHS=None,
                          BATCH_SIZE=None,
                          reward_fun_on=None,
                          reward_fun_off=None,
                          max_steps=277,
                          lr_on=2e-4,
                          lr_discrim=1e-4,
                          p_thresh=0.5,
                          drop_rate=0.0,
                          plot_ignore_initial=0,
                          randomize_reward=False,
                          save_file=None,
                          reward_sm=0.0,
                          preload_file=None,
                          anchor_file=None,
                          anchor_weight=0.0,
                          decoder_type='action',
                          plot_prefix='',
                          dashboard='policy gradient',
                          smiles_save_file=None,
                          on_policy_loss_type='best',
                          off_policy_loss_type='mean',
                          sanity_checks=True):
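    """Train a molecule generator with policy gradients while also training a
    GraphDiscriminator to tell generated SMILES from ZINC; the discriminator's
    output and an originality multiplier are folded into the on-policy reward.
    Returns the generator model, the on-policy fitter generator, and the
    discriminator fitter."""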
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../../'
    gen_save_path = root_location + 'pretrained/gen_' + save_file
    disc_save_path = root_location + 'pretrained/disc_' + save_file

    if smiles_save_file is not None:
        smiles_save_path = root_location + 'pretrained/' + smiles_save_file
        save_dataset = IncrementingHDF5Dataset(smiles_save_path)
    else:
        save_dataset = None

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    discrim_model = GraphDiscriminator(codec.grammar, drop_rate=drop_rate)

    zinc_data = get_zinc_smiles()
    zinc_set = set(zinc_data)
    lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

    def originality_mult(smiles_list):
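        # discount molecules we have already seen: 0.5 if in ZINC or in the most recent batches,
        # 0.70 / 0.85 for the longer lookback windows, 1.0 if genuinely new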
        out = []
        for s in smiles_list:
            if s in zinc_set:
                out.append(0.5)
            elif s in history_data[0]:
                out.append(0.5)
            elif s in history_data[1]:
                out.append(0.70)
            elif s in history_data[2]:
                out.append(0.85)
            else:
                out.append(1.0)
        return np.array(out)

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def discriminator_reward_mult(smiles_list):
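        # probability the discriminator assigns to each molecule being from ZINC (evaluated in eval mode)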
        orig_state = discrim_model.training
        discrim_model.eval()
        discrim_out_logits = discrim_model(smiles_list)['p_zinc']
        discrim_probs = F.softmax(discrim_out_logits, dim=1)
        prob_zinc = discrim_probs[:, 1].detach().cpu().numpy()
        if orig_state:
            discrim_model.train()
        return prob_zinc

    def adj_reward(x):
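        # clipped task reward scaled by originality, plus a bonus of twice the discriminator's p(ZINC)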
        p = discriminator_reward_mult(x)
        reward = np.maximum(reward_fun_on(x), 0)
        out = reward * originality_mult(x) + 2 * p
        return out

    def adj_reward_old(x):
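        # older variant: sigmoid-weighted blend of discriminator probability and the floored task reward,
        # optionally randomized, then scaled by originality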
        p = discriminator_reward_mult(x)
        w = sigmoid(-(p - p_thresh) / 0.01)
        if randomize_reward:
            rand = np.random.uniform(size=p.shape)
            w *= rand
        reward = np.maximum(reward_fun_on(x), p_thresh)
        weighted_reward = w * p + (1 - w) * reward
        out = weighted_reward * originality_mult(x)  #
        return out

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=adj_reward,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=save_dataset)

    model = get_decoder(molecules,
                        grammar,
                        z_size=settings['z_size'],
                        decoder_hidden_n=200,
                        feature_len=codec.feature_len(),
                        max_seq_length=max_steps,
                        drop_rate=drop_rate,
                        decoder_type=decoder_type,
                        task=task)[0]

    # TODO: really ugly, refactor! In fact this model doesn't need a MaskingHead at all!
    model.stepper.model.mask_gen.priors = True  # use empirical priors for the mask gen; 'conditional' is the other option
    # if preload_file is not None:
    #     try:
    #         preload_path = root_location + 'pretrained/' + preload_file
    #         model.load_state_dict(torch.load(preload_path))
    #     except:
    #         pass

    anchor_model = None

    from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer
    import rdkit.Chem.rdMolDescriptors as desc
    import numpy as np
    scorer = NormalizedScorer()

    def model_process_fun(model_out, visdom, n):
        # TODO: rephrase this to return a dict, instead of calling visdom directly
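        # renders the best-reward molecule of the batch as an SVG and logs its score components
        # and the batch's fraction of valid molecules to visdom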
        from rdkit import Chem
        from rdkit.Chem.Draw import MolToFile
        # actions, logits, rewards, terminals, info = model_out
        smiles, valid = model_out['info']
        total_rewards = model_out['rewards'].sum(1)
        best_ind = torch.argmax(total_rewards).data.item()
        this_smile = smiles[best_ind]
        mol = Chem.MolFromSmiles(this_smile)
        pic_save_path = root_location + 'images/' + 'tmp.svg'
        if mol is not None:
            try:
                MolToFile(mol, pic_save_path, imageType='svg')
                with open(pic_save_path, 'r') as myfile:
                    data = myfile.read()
                data = data.replace('svg:', '')
                visdom.append('best molecule of batch', 'svg', svgstr=data)
            except Exception as e:
                print(e)
            scores, norm_scores = scorer.get_scores([this_smile])
            visdom.append(
                'score component',
                'line',
                X=np.array([n]),
                Y=np.array(
                    [[x for x in norm_scores[0]] + [norm_scores[0].sum()] +
                     [scores[0].sum()] + [desc.CalcNumAromaticRings(mol)] +
                     [total_rewards[best_ind].item()]]),
                opts={
                    'legend': [
                        'logP', 'SA', 'cycle', 'norm_reward', 'reward',
                        'Aromatic rings', 'eff_reward'
                    ]
                })
            visdom.append('fraction valid',
                          'line',
                          X=np.array([n]),
                          Y=np.array([valid.mean().data.item()]))

    if reward_fun_off is None:
        reward_fun_off = reward_fun_on

    # construct the loader to feed the discriminator
    def make_callback(data):
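        # returns a fitter callback that harvests SMILES from the generated graphs into the rolling history deques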
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            for s in smiles:  # only store unique instances of molecules so discriminator can't guess on frequency
                if s not in data:
                    data.append(s)

        return hc

    # the history deques were seeded with a placeholder ('O') above: something needs to be there
    # to begin with, else the DataLoader constructor barfs

    def get_rl_fitter(model,
                      loss_obj,
                      train_gen,
                      save_path,
                      fit_plot_prefix='',
                      model_process_fun=None,
                      lr=None,
                      extra_callbacks=[],
                      loss_display_cap=float('inf'),
                      anchor_model=None,
                      anchor_weight=0):
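        # wires up an Adam optimizer with StepLR decay, optional visdom monitoring and checkpointing,
        # and returns the fit_rl fitter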
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun,
                smooth_weight=0.9)
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=True)

        fitter = fit_rl(train_gen=train_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        callbacks=[metric_monitor, checkpointer] +
                        extra_callbacks)

        return fitter

    class GeneratorToIterable:
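        # wraps a generator factory so it can be re-iterated and exposes __len__
        # (computed by exhausting one pass, hence the finiteness assumption below)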
        def __init__(self, gen):
            self.gen = gen
            # we assume the generator is finite
            self.len = 0
            for _ in gen():
                self.len += 1

        def __len__(self):
            return self.len

        def __iter__(self):
            return self.gen()

    def my_gen():
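        # yields 1000 batches of zero latent vectors per pass as inputs to the on-policy decoder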
        for _ in range(1000):
            yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

    # the on-policy fitter

    history_callbacks = [make_callback(d) for d in history_data]
    fitter1 = get_rl_fitter(model,
                            PolicyGradientLoss(on_policy_loss_type,
                                               last_reward_wgt=reward_sm),
                            GeneratorToIterable(my_gen),
                            gen_save_path,
                            plot_prefix + 'on-policy',
                            model_process_fun=model_process_fun,
                            lr=lr_on,
                            extra_callbacks=history_callbacks,
                            anchor_model=anchor_model,
                            anchor_weight=anchor_weight)
    # get existing molecule data to add to the training mix
    pre_dataset = EvenlyBlendedDataset(
        2 * [history_data[0]] + history_data[1:],
        labels=False)  # a blend of 3 time horizons
    dataset = EvenlyBlendedDataset([pre_dataset, zinc_data], labels=True)
    discrim_loader = DataLoader(dataset, shuffle=True, batch_size=50)
    celoss = nn.CrossEntropyLoss()

    def my_loss(x):
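        # cross-entropy between the discriminator's p_zinc logits and the dataset-origin labels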
        # tmp = discriminator_reward_mult(x['smiles'])
        # tmp2 = F.softmax(x['p_zinc'], dim=1)[:,1].detach().cpu().numpy()
        # import numpy as np
        # assert np.max(np.abs(tmp-tmp2)) < 1e-6
        return celoss(x['p_zinc'].to(device), x['dataset_index'].to(device))

    fitter2 = get_rl_fitter(discrim_model,
                            my_loss,
                            IterableTransform(
                                discrim_loader, lambda x: {
                                    'smiles': x['X'],
                                    'dataset_index': x['dataset_index']
                                }),
                            disc_save_path,
                            plot_prefix + ' discriminator',
                            lr=lr_discrim,
                            model_process_fun=None)

    def on_policy_gen(fitter, model):
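        # re-instantiate the sampling policy before each on-policy step, then advance the fitter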
        while True:
            model.policy = SoftmaxRandomSamplePolicy()  # bias=codec.grammar.get_log_frequencies()
            yield next(fitter)

    return model, on_policy_gen(fitter1, model), fitter2