Example #1
    # assumes numpy (np) plus guacamol_goal_scoring_functions and version_name_list
    # from generative_playground.molecules.guacamol_utils (see the imports in Example #3)
    def test_scoring_functions(self):
        smiles = ['O', 'cccccc']  # 'O' is water; 'cccccc' is not valid SMILES, so invalid input is exercised too
        for ver in version_name_list:
            print(ver)
            objectives = guacamol_goal_scoring_functions(ver)
            for obj in objectives:
                out = obj(smiles)
                print(obj.name, out)
                assert not np.isnan(sum(out)), "Objective returned NaN value!"
# assumes PolicyGradientRunner, generate_root_name and guacamol_goal_scoring_functions
# are imported from the generative_playground package (see Example #3 for the guacamol_utils import)
def run_initial_scan(num_batches=100,
                     batch_size=30,
                     snapshot_dir=None,
                     entropy_wgt=0.0,
                     root_name=None,
                     obj_num=None,
                     ver='v2',
                     lr=0.01,
                     attempt='',
                     plot=False):
    grammar_cache = 'hyper_grammar_guac_10k_with_clique_collapse.pickle'  # 'hyper_grammar.pickle'
    grammar = 'hypergraph:' + grammar_cache
    reward_funs = guacamol_goal_scoring_functions(ver)
    reward_fun = reward_funs[obj_num]

    first_runner = lambda: PolicyGradientRunner(
        grammar,
        BATCH_SIZE=batch_size,
        reward_fun=reward_fun,
        max_steps=60,
        num_batches=num_batches,
        lr=lr,
        entropy_wgt=entropy_wgt,
        # lr_schedule=shifted_cosine_schedule,
        root_name=root_name,
        preload_file_root_name=None,
        plot_metrics=plot,
        save_location=snapshot_dir,
        metric_smooth=0.0,
        decoder_type='graph_conditional',  # alternatives: 'rnn_graph', 'attention'
        on_policy_loss_type='advantage_record',
        rule_temperature_schedule=None,
        # lambda x: toothy_exp_schedule(x, scale=num_batches),
        eps=0.0,
        priors='conditional',
    )

    run = 0
    # keeps spawning fresh runners under new root names; interrupt to stop
    while True:
        model = first_runner()
        orig_name = model.root_name
        model.set_root_name(generate_root_name(orig_name, {}))
        model.run()
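
A minimal driver sketch for the scan above, shown only for orientation: the argument values, snapshot directory and root name below are illustrative assumptions rather than settings taken from the original example, and the scan keeps spawning runs until interrupted.

if __name__ == '__main__':
    run_initial_scan(num_batches=10,                  # illustrative values throughout
                     batch_size=20,
                     snapshot_dir='./snapshots',      # hypothetical output directory
                     root_name='initial_scan_demo',   # hypothetical run name
                     obj_num=0,                       # first objective of the chosen GuacaMol version
                     ver='v2',
                     lr=0.01,
                     plot=False)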
Example #3
from generative_playground.molecules.rdkit_utils.rdkit_utils import num_atoms, num_aromatic_rings, NormalizedScorer
# from generative_playground.models.problem.rl.DeepRL_wrappers import BodyAdapter, MyA2CAgent
from generative_playground.molecules.model_settings import get_settings
from generative_playground.molecules.train.pg.hypergraph.main_train_policy_gradient_minimal import train_policy_gradient
from generative_playground.codec.hypergraph_grammar import GrammarInitializer
from generative_playground.molecules.guacamol_utils import guacamol_goal_scoring_functions, version_name_list

batch_size = 15  # 20
drop_rate = 0.5
molecules = True
grammar_cache = 'hyper_grammar_guac_10k_with_clique_collapse.pickle'  #'hyper_grammar.pickle'
grammar = 'hypergraph:' + grammar_cache
settings = get_settings(molecules, grammar)
ver = 'trivial'
obj_num = 0
reward_funs = guacamol_goal_scoring_functions(ver)
reward_fun = reward_funs[obj_num]
# # later will run this ahead of time
# gi = GrammarInitializer(grammar_cache)

root_name = 'canned_' + ver + '_' + str(obj_num) + 'do 0.5 lr4e-5'
max_steps = 45
model, gen_fitter, disc_fitter = train_policy_gradient(
    molecules,
    grammar,
    EPOCHS=100,
    BATCH_SIZE=batch_size,
    reward_fun_on=reward_fun,
    max_steps=max_steps,
    lr_on=4e-5,
    lr_discrim=0.0,
Example #4
# Monte Carlo tree search over the hypergraph grammar, scored by a GuacaMol objective
def run_mcts(
    num_batches=10,  # respawn after that - workaround for memory leak
    batch_size=20,
    ver='trivial',  # alternative: 'v2'
    obj_num=4,
    grammar_cache='hyper_grammar_guac_10k_with_clique_collapse.pickle',  # 'hyper_grammar.pickle'
    max_seq_length=60,
    root_name='',
    compress_data_store=True,
    kind='thompson_local',
    reset_cache=False,
    penalize_repetition=True,
    save_every=10,
    num_bins=50,  # TODO: replace with a Value Distribution object
    updates_to_refresh=10,
    decay=0.95,
    lr=0.05,
    grad_clip=5,
    entropy_weight=3,
):

    if penalize_repetition:
        zinc_set = set(get_smiles_from_database(source='ChEMBL:train'))
        lookbacks = [batch_size, 10 * batch_size, 100 * batch_size]
        pre_reward_fun = guacamol_goal_scoring_functions(ver)[obj_num]
        reward_fun_ = AdjustedRewardCalculator(pre_reward_fun, zinc_set,
                                               lookbacks)
        # reward_fun_ = CountRewardAdjuster(pre_reward_fun)
    else:
        reward_fun_ = lambda x: guacamol_goal_scoring_functions(ver)[obj_num]([x])[0]

    # load or create the global variables needed
    here = os.path.dirname(__file__)
    save_path = os.path.realpath(
        os.path.join(here, '../../../../molecules/train/mcts/data/')) + '/'
    globals_name = os.path.realpath(save_path + root_name + '.gpkl')
    db_path = '/' + os.path.realpath(save_path + root_name + '.db').replace(
        '\\', '/')

    plotter = MetricPlotter(plot_prefix='',
                            save_file=None,
                            loss_display_cap=4,
                            dashboard_name=root_name,
                            plot_ignore_initial=0,
                            process_model_fun=model_process_fun,
                            extra_metric_fun=None,
                            smooth_weight=0.5,
                            frequent_calls=False)

    # if 'thompson' in kind:
    #     my_globals = get_thompson_globals(num_bins=num_bins,
    #                                       reward_fun_=reward_fun_,
    #                                       grammar_cache=grammar_cache,
    #                                       max_seq_length=max_seq_length,
    #                                       decay=decay,
    #                                       updates_to_refresh=updates_to_refresh,
    #                                       plotter=plotter
    #                                       )
    # elif 'model' in kind:
    my_globals = GlobalParametersModel(
        batch_size=batch_size,
        reward_fun_=reward_fun_,
        grammar_cache=grammar_cache,  # 'hyper_grammar.pickle'
        max_depth=max_seq_length,
        lr=lr,
        grad_clip=grad_clip,
        entropy_weight=entropy_weight,
        decay=decay,
        num_bins=num_bins,
        updates_to_refresh=updates_to_refresh,
        plotter=plotter,
        degenerate=(kind == 'model_thompson'))
    if reset_cache:
        try:
            os.remove(globals_name)
            print('removed globals cache ' + globals_name)
        except OSError:
            print("Could not remove globals cache " + globals_name)
        try:
            os.remove(db_path[1:])
            print('removed locals cache ' + db_path[1:])
        except OSError:
            print("Could not remove locals cache " + db_path[1:])
    else:
        try:
            with gzip.open(globals_name) as f:
                global_state = dill.load(f)
                my_globals.set_mutable_state(global_state)
                print("Loaded global state cache!")
        except Exception:
            # no cached global state yet (or it could not be read), so start fresh
            pass

    node_type = class_from_kind(kind)
    from generative_playground.utils.deep_getsizeof import memory_by_type
    with Shelve(db_path, 'kv_table',
                compress=compress_data_store) as state_store:
        my_globals.state_store = state_store
        root_node = node_type(my_globals,
                              parent=None,
                              source_action=None,
                              depth=1)

        for b in range(num_batches):
            mem = memory_by_type()
            print("memory pre-explore", sum([x[2] for x in mem]), mem[:5])
            rewards, infos = explore(root_node, batch_size)
            state_store.flush()
            mem = memory_by_type()
            print("memory post-explore", sum([x[2] for x in mem]), mem[:5])
            if b % save_every == save_every - 1:
                print('**** saving global state ****')
                with gzip.open(globals_name, 'wb') as f:
                    dill.dump(my_globals.get_mutable_state(), f)

            # visualisation code goes here
            plotter_input = {
                'rewards': np.array(rewards),
                'info': [[x['smiles'] for x in infos],
                         np.ones(len(infos))]
            }
            my_globals.plotter(None, None, plotter_input, None, None)
            print(max(rewards))

    # print(root_node.result_repo.avg_reward())
    # temp_dir.cleanup()
    print("done!")
# genetic search over PolicyGradientRunner instances: explore first, then pick, mutate and crossover the best runs
def run_genetic_opt(
        top_N=10,
        p_mutate=0.2,
        mutate_num_best=64,
        mutate_use_total_probs=False,
        p_crossover=0.2,
        num_batches=100,
        batch_size=30,
        snapshot_dir=None,
        entropy_wgt=0.0,
        root_name=None,
        obj_num=None,
        ver='v2',
        lr=0.01,
        num_runs=100,
        num_explore=5,
        plot_single_runs=True,
        steps_with_no_improvement=10,
        reward_aggregation=np.median,
        attempt='',  # only used for disambiguating plotting
        max_steps=90,
        past_runs_graph_file=None):

    manager = mp.Manager()
    queue = manager.Queue()

    relationships = nx.DiGraph()
    grammar_cache = 'hyper_grammar_guac_10k_with_clique_collapse.pickle'  # 'hyper_grammar.pickle'
    grammar = 'hypergraph:' + grammar_cache

    reward_funs = guacamol_goal_scoring_functions(ver)
    reward_fun = reward_funs[obj_num]

    split_name = root_name.split('_')
    split_name[0] += 'Stats'
    dash_name = '_'.join(split_name) + attempt
    vis = Dashboard(dash_name, call_every=1)

    first_runner_factory = lambda: PolicyGradientRunner(
        grammar,
        BATCH_SIZE=batch_size,
        reward_fun=reward_fun,
        max_steps=max_steps,
        num_batches=num_batches,
        lr=lr,
        entropy_wgt=entropy_wgt,
        # lr_schedule=shifted_cosine_schedule,
        root_name=root_name,
        preload_file_root_name=None,
        plot_metrics=plot_single_runs,
        save_location=snapshot_dir,
        metric_smooth=0.0,
        decoder_type='graph_conditional_sparse',
        # alternatives: 'graph_conditional', 'rnn_graph', 'attention'
        on_policy_loss_type='advantage_record',
        rule_temperature_schedule=None,
        # lambda x: toothy_exp_schedule(x, scale=num_batches),
        eps=0.0,
        priors='conditional',
    )

    init_thresh = 50
    pca_dim = 10
    if past_runs_graph_file:
        params, rewards = extract_params_rewards(past_runs_graph_file)
        sampler = ParameterSampler(params,
                                   rewards,
                                   init_thresh=init_thresh,
                                   pca_dim=pca_dim)
    else:
        sampler = None
    data_cache = {}
    best_so_far = float('-inf')
    steps_since_best = 0

    initial = True
    should_stop = False
    run = 0

    with mp.Pool(4) as p:
        while not should_stop:
            data_cache = populate_data_cache(snapshot_dir, data_cache)
            if run < num_explore:
                model = first_runner_factory()
                if sampler:
                    model.params = sampler.sample()
            else:
                model = (pick_model_to_run(data_cache,
                                           PolicyGradientRunner,
                                           snapshot_dir,
                                           num_best=top_N)
                         if data_cache else first_runner_factory())

            orig_name = model.root_name
            model.set_root_name(generate_root_name(orig_name, data_cache))

            if run > num_explore:
                relationships.add_edge(orig_name, model.root_name)

                if random.random() < p_crossover and len(data_cache) > 1:
                    second_model = pick_model_for_crossover(
                        data_cache, model, PolicyGradientRunner, snapshot_dir)
                    model = classic_crossover(model, second_model)
                    relationships.add_edge(second_model.root_name,
                                           model.root_name)

                if random.random() < p_mutate:
                    model = mutate(model,
                                   pick_best=mutate_num_best,
                                   total_probs=mutate_use_total_probs)
                    # .nodes (networkx 2.x) rather than the deprecated .node attribute
                    relationships.nodes[model.root_name]['mutated'] = True
                else:
                    relationships.nodes[model.root_name]['mutated'] = False

                with open(
                        snapshot_dir + '/' + model.root_name + '_lineage.pkl',
                        'wb') as f:
                    pickle.dump(relationships, f)

            model.save()

            if initial:
                for _ in range(4):
                    print('Starting {}'.format(run))
                    p.apply_async(run_model,
                                  (queue, model.root_name, run, snapshot_dir))
                    run += 1
                initial = False
            else:
                print('Starting {}'.format(run))
                p.apply_async(run_model,
                              (queue, model.root_name, run, snapshot_dir))
                run += 1

            finished_run, finished_root_name = queue.get(block=True)
            print('Finished: {}'.format(finished_root_name))

            data_cache = populate_data_cache(snapshot_dir, data_cache)
            my_rewards = data_cache[finished_root_name]['best_rewards']
            metrics = {
                'max': my_rewards.max(),
                'median': np.median(my_rewards),
                'min': my_rewards.min()
            }
            metric_dict = {
                'type': 'line',
                'X': np.array([finished_run]),
                'Y': np.array([[val for key, val in metrics.items()]]),
                'opts': {
                    'legend': [key for key, val in metrics.items()]
                }
            }

            vis.plot_metric_dict({'worker rewards': metric_dict})

            this_agg_reward = reward_aggregation(my_rewards)
            if this_agg_reward > best_so_far:
                best_so_far = this_agg_reward
                steps_since_best = 0
            else:
                steps_since_best += 1

            should_stop = (
                steps_since_best >= steps_with_no_improvement
                and finished_run > num_explore + steps_with_no_improvement)

        p.terminate()

    return extract_best(data_cache, 1)
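
Finally, a call sketch for the genetic optimiser; the values are illustrative assumptions, the root name is split on '_' to build the stats dashboard name, and the mp.Pool used inside makes a __main__ guard advisable.

if __name__ == '__main__':
    best = run_genetic_opt(top_N=10,
                           num_batches=20,               # illustrative values throughout
                           batch_size=30,
                           snapshot_dir='./snapshots',   # hypothetical snapshot location
                           root_name='genetic_v2_demo',  # split on '_' for the stats dashboard name
                           obj_num=0,
                           ver='v2',
                           num_explore=5,
                           steps_with_no_improvement=10,
                           max_steps=60)
    print(best)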