def test_scoring_functions(self):
    # Check that every objective in every version returns a finite score
    # for both a valid SMILES ('O', water) and a malformed one ('cccccc')
    smiles = ['O', 'cccccc']
    for ver in version_name_list:
        print(ver)
        objectives = guacamol_goal_scoring_functions(ver)
        for obj in objectives:
            out = obj(smiles)
            print(obj.name, out)
            assert not np.isnan(sum(out)), "Objective returned NaN value!"
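# Hedged usage sketch for the objectives the test exercises: each scoring
# function maps a list of SMILES strings to one float per string. The 'v2'
# version name comes from the scripts below; 'CCO' (ethanol) is just an
# illustrative input, not taken from the original tests.
if __name__ == '__main__':
    import numpy as np
    from generative_playground.molecules.guacamol_utils import guacamol_goal_scoring_functions

    objectives = guacamol_goal_scoring_functions('v2')
    scores = objectives[0](['O', 'CCO'])  # one score per input SMILES
    print(scores)
    assert not any(np.isnan(s) for s in scores)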
def run_initial_scan(num_batches=100,
                     batch_size=30,
                     snapshot_dir=None,
                     entropy_wgt=0.0,
                     root_name=None,
                     obj_num=None,
                     ver='v2',
                     lr=0.01,
                     attempt='',
                     plot=False):
    grammar_cache = 'hyper_grammar_guac_10k_with_clique_collapse.pickle'  # 'hyper_grammar.pickle'
    grammar = 'hypergraph:' + grammar_cache
    reward_funs = guacamol_goal_scoring_functions(ver)
    reward_fun = reward_funs[obj_num]

    first_runner = lambda: PolicyGradientRunner(grammar,
                                                BATCH_SIZE=batch_size,
                                                reward_fun=reward_fun,
                                                max_steps=60,
                                                num_batches=num_batches,
                                                lr=lr,
                                                entropy_wgt=entropy_wgt,
                                                # lr_schedule=shifted_cosine_schedule,
                                                root_name=root_name,
                                                preload_file_root_name=None,
                                                plot_metrics=plot,
                                                save_location=snapshot_dir,
                                                metric_smooth=0.0,
                                                decoder_type='graph_conditional',  # 'rnn_graph', # 'attention',
                                                on_policy_loss_type='advantage_record',
                                                rule_temperature_schedule=None,  # lambda x: toothy_exp_schedule(x, scale=num_batches),
                                                eps=0.0,
                                                priors='conditional',
                                                )
    run = 0
    while True:
        # respawn a fresh runner on every pass, under a new unique root name
        model = first_runner()
        orig_name = model.root_name
        model.set_root_name(generate_root_name(orig_name, {}))
        model.run()
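# A minimal invocation sketch for run_initial_scan; the snapshot_dir and
# root_name values are assumptions for illustration, not from the original
# experiments. Note the function loops forever, respawning a fresh
# PolicyGradientRunner on each pass.
if __name__ == '__main__':
    run_initial_scan(num_batches=100,
                     batch_size=30,
                     snapshot_dir='./snapshots',  # assumed location
                     root_name='scan_v2_obj0',    # assumed name
                     obj_num=0,
                     ver='v2')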
from generative_playground.molecules.rdkit_utils.rdkit_utils import num_atoms, num_aromatic_rings, NormalizedScorer
# from generative_playground.models.problem.rl.DeepRL_wrappers import BodyAdapter, MyA2CAgent
from generative_playground.molecules.model_settings import get_settings
from generative_playground.molecules.train.pg.hypergraph.main_train_policy_gradient_minimal import train_policy_gradient
from generative_playground.codec.hypergraph_grammar import GrammarInitializer
from generative_playground.molecules.guacamol_utils import guacamol_goal_scoring_functions, version_name_list

batch_size = 15  # 20
drop_rate = 0.5
molecules = True
grammar_cache = 'hyper_grammar_guac_10k_with_clique_collapse.pickle'  # 'hyper_grammar.pickle'
grammar = 'hypergraph:' + grammar_cache
settings = get_settings(molecules, grammar)
ver = 'trivial'
obj_num = 0
reward_funs = guacamol_goal_scoring_functions(ver)
reward_fun = reward_funs[obj_num]
# later will run this ahead of time
# gi = GrammarInitializer(grammar_cache)

root_name = 'canned_' + ver + '_' + str(obj_num) + 'do 0.5 lr4e-5'
max_steps = 45
model, gen_fitter, disc_fitter = train_policy_gradient(molecules,
                                                       grammar,
                                                       EPOCHS=100,
                                                       BATCH_SIZE=batch_size,
                                                       reward_fun_on=reward_fun,
                                                       max_steps=max_steps,
                                                       lr_on=4e-5,
                                                       lr_discrim=0.0,
def run_mcts(num_batches=10,  # respawn after that - workaround for memory leak
             batch_size=20,
             ver='trivial',  # 'v2'
             obj_num=4,
             grammar_cache='hyper_grammar_guac_10k_with_clique_collapse.pickle',  # 'hyper_grammar.pickle'
             max_seq_length=60,
             root_name='',
             compress_data_store=True,
             kind='thompson_local',
             reset_cache=False,
             penalize_repetition=True,
             save_every=10,
             num_bins=50,  # TODO: replace with a Value Distribution object
             updates_to_refresh=10,
             decay=0.95,
             lr=0.05,
             grad_clip=5,
             entropy_weight=3):
    if penalize_repetition:
        # discount rewards for molecules already in the training set
        # or seen in recent batches
        zinc_set = set(get_smiles_from_database(source='ChEMBL:train'))
        lookbacks = [batch_size, 10 * batch_size, 100 * batch_size]
        pre_reward_fun = guacamol_goal_scoring_functions(ver)[obj_num]
        reward_fun_ = AdjustedRewardCalculator(pre_reward_fun, zinc_set, lookbacks)
        # reward_fun_ = CountRewardAdjuster(pre_reward_fun)
    else:
        # single-molecule wrapper around the batch scorer
        reward_fun_ = lambda x: guacamol_goal_scoring_functions(ver)[obj_num]([x])[0]

    # load or create the global variables needed
    here = os.path.dirname(__file__)
    save_path = os.path.realpath(here + '/../../../../molecules/train/mcts/data/') + '/'
    globals_name = os.path.realpath(save_path + root_name + '.gpkl')
    db_path = '/' + os.path.realpath(save_path + root_name + '.db').replace('\\', '/')

    plotter = MetricPlotter(plot_prefix='',
                            save_file=None,
                            loss_display_cap=4,
                            dashboard_name=root_name,
                            plot_ignore_initial=0,
                            process_model_fun=model_process_fun,
                            extra_metric_fun=None,
                            smooth_weight=0.5,
                            frequent_calls=False)

    # if 'thompson' in kind:
    #     my_globals = get_thompson_globals(num_bins=num_bins,
    #                                       reward_fun_=reward_fun_,
    #                                       grammar_cache=grammar_cache,
    #                                       max_seq_length=max_seq_length,
    #                                       decay=decay,
    #                                       updates_to_refresh=updates_to_refresh,
    #                                       plotter=plotter)
    # elif 'model' in kind:
    my_globals = GlobalParametersModel(batch_size=batch_size,
                                       reward_fun_=reward_fun_,
                                       grammar_cache=grammar_cache,  # 'hyper_grammar.pickle'
                                       max_depth=max_seq_length,
                                       lr=lr,
                                       grad_clip=grad_clip,
                                       entropy_weight=entropy_weight,
                                       decay=decay,
                                       num_bins=num_bins,
                                       updates_to_refresh=updates_to_refresh,
                                       plotter=plotter,
                                       degenerate=(kind == 'model_thompson'))

    if reset_cache:
        try:
            os.remove(globals_name)
            print('removed globals cache ' + globals_name)
        except OSError:
            print('Could not remove globals cache ' + globals_name)
        try:
            os.remove(db_path[1:])
            print('removed locals cache ' + db_path[1:])
        except OSError:
            print('Could not remove locals cache ' + db_path[1:])
    else:
        try:
            with gzip.open(globals_name) as f:
                global_state = dill.load(f)
            my_globals.set_mutable_state(global_state)
            print("Loaded global state cache!")
        except Exception:
            pass  # no cache yet, start fresh

    node_type = class_from_kind(kind)
    from generative_playground.utils.deep_getsizeof import memory_by_type
    with Shelve(db_path, 'kv_table', compress=compress_data_store) as state_store:
        my_globals.state_store = state_store
        root_node = node_type(my_globals,
                              parent=None,
                              source_action=None,
                              depth=1)

        for b in range(num_batches):
            mem = memory_by_type()
            print("memory pre-explore", sum([x[2] for x in mem]), mem[:5])
            rewards, infos = explore(root_node, batch_size)
            state_store.flush()
            mem = memory_by_type()
            print("memory post-explore", sum([x[2] for x in mem]), mem[:5])

            if b % save_every == save_every - 1:
                print('**** saving global state ****')
                with gzip.open(globals_name, 'wb') as f:
                    dill.dump(my_globals.get_mutable_state(), f)

            # visualisation code goes here
            plotter_input = {'rewards': np.array(rewards),
                             'info': [[x['smiles'] for x in infos], np.ones(len(infos))]}
            my_globals.plotter(None, None, plotter_input, None, None)
            print(max(rewards))
            # print(root_node.result_repo.avg_reward())
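# A minimal invocation sketch for run_mcts; root_name is an assumed value
# and determines where the global/local caches land under
# .../molecules/train/mcts/data/. With the other arguments left at their
# defaults, it runs num_batches rounds of explore() and pickles the global
# state every save_every batches.
if __name__ == '__main__':
    run_mcts(num_batches=10,
             batch_size=20,
             ver='trivial',
             obj_num=0,                    # illustrative objective index
             root_name='mcts_trivial_0')   # assumed name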
temp_dir.cleanup()
print("done!")
def run_genetic_opt(top_N=10,
                    p_mutate=0.2,
                    mutate_num_best=64,
                    mutate_use_total_probs=False,
                    p_crossover=0.2,
                    num_batches=100,
                    batch_size=30,
                    snapshot_dir=None,
                    entropy_wgt=0.0,
                    root_name=None,
                    obj_num=None,
                    ver='v2',
                    lr=0.01,
                    num_runs=100,
                    num_explore=5,
                    plot_single_runs=True,
                    steps_with_no_improvement=10,
                    reward_aggregation=np.median,
                    attempt='',  # only used for disambiguating plotting
                    max_steps=90,
                    past_runs_graph_file=None):
    manager = mp.Manager()
    queue = manager.Queue()
    relationships = nx.DiGraph()  # lineage graph: which runs spawned which

    grammar_cache = 'hyper_grammar_guac_10k_with_clique_collapse.pickle'  # 'hyper_grammar.pickle'
    grammar = 'hypergraph:' + grammar_cache
    reward_funs = guacamol_goal_scoring_functions(ver)
    reward_fun = reward_funs[obj_num]

    split_name = root_name.split('_')
    split_name[0] += 'Stats'
    dash_name = '_'.join(split_name) + attempt
    vis = Dashboard(dash_name, call_every=1)

    first_runner_factory = lambda: PolicyGradientRunner(grammar,
                                                        BATCH_SIZE=batch_size,
                                                        reward_fun=reward_fun,
                                                        max_steps=max_steps,
                                                        num_batches=num_batches,
                                                        lr=lr,
                                                        entropy_wgt=entropy_wgt,
                                                        # lr_schedule=shifted_cosine_schedule,
                                                        root_name=root_name,
                                                        preload_file_root_name=None,
                                                        plot_metrics=plot_single_runs,
                                                        save_location=snapshot_dir,
                                                        metric_smooth=0.0,
                                                        decoder_type='graph_conditional_sparse',  # 'graph_conditional', # 'rnn_graph', # 'attention',
                                                        on_policy_loss_type='advantage_record',
                                                        rule_temperature_schedule=None,  # lambda x: toothy_exp_schedule(x, scale=num_batches),
                                                        eps=0.0,
                                                        priors='conditional',
                                                        )
    init_thresh = 50
    pca_dim = 10
    if past_runs_graph_file:
        params, rewards = extract_params_rewards(past_runs_graph_file)
        sampler = ParameterSampler(params,
                                   rewards,
                                   init_thresh=init_thresh,
                                   pca_dim=pca_dim)
    else:
        sampler = None

    data_cache = {}
    best_so_far = float('-inf')
    steps_since_best = 0
    initial = True
    should_stop = False
    run = 0
    with mp.Pool(4) as p:
        while not should_stop:
            data_cache = populate_data_cache(snapshot_dir, data_cache)
            if run < num_explore:
                # exploration phase: start from scratch, or from a sampled parameter vector
                model = first_runner_factory()
                if sampler:
                    model.params = sampler.sample()
            else:
                # exploitation phase: continue one of the best runs so far
                model = (pick_model_to_run(data_cache,
                                           PolicyGradientRunner,
                                           snapshot_dir,
                                           num_best=top_N)
                         if data_cache else first_runner_factory())

            orig_name = model.root_name
            model.set_root_name(generate_root_name(orig_name, data_cache))
            relationships.add_node(model.root_name)  # ensure the node exists before annotating it

            if run > num_explore:
                relationships.add_edge(orig_name, model.root_name)

            if random.random() < p_crossover and len(data_cache) > 1:
                second_model = pick_model_for_crossover(data_cache,
                                                        model,
                                                        PolicyGradientRunner,
                                                        snapshot_dir)
                model = classic_crossover(model, second_model)
                relationships.add_edge(second_model.root_name, model.root_name)

            if random.random() < p_mutate:
                model = mutate(model,
                               pick_best=mutate_num_best,
                               total_probs=mutate_use_total_probs)
                relationships.nodes[model.root_name]['mutated'] = True
            else:
                relationships.nodes[model.root_name]['mutated'] = False

            with open(snapshot_dir + '/' + model.root_name + '_lineage.pkl', 'wb') as f:
                pickle.dump(relationships, f)
            model.save()

            if initial is True:
                # prime the pool with as many runs as there are workers
                for _ in range(4):
                    print('Starting {}'.format(run))
                    p.apply_async(run_model, (queue, model.root_name, run, snapshot_dir))
                    run += 1
                initial = False
            else:
                print('Starting {}'.format(run))
                p.apply_async(run_model, (queue, model.root_name, run, snapshot_dir))
                run += 1

            finished_run, finished_root_name = queue.get(block=True)
            print('Finished: {}'.format(finished_root_name))
            data_cache = populate_data_cache(snapshot_dir, data_cache)
            my_rewards = data_cache[finished_root_name]['best_rewards']
            metrics = {'max': my_rewards.max(),
                       'median': np.median(my_rewards),
                       'min': my_rewards.min()}
            metric_dict = {'type': 'line',
                           'X': np.array([finished_run]),
                           'Y': np.array([list(metrics.values())]),
                           'opts': {'legend': list(metrics.keys())}}
            vis.plot_metric_dict({'worker rewards': metric_dict})

            this_agg_reward = reward_aggregation(my_rewards)
            if this_agg_reward > best_so_far:
                best_so_far = this_agg_reward
                steps_since_best = 0
            else:
                steps_since_best += 1

            should_stop = (steps_since_best >= steps_with_no_improvement
                           and finished_run > num_explore + steps_with_no_improvement)

        p.terminate()
    return extract_best(data_cache, 1)
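# A minimal invocation sketch for run_genetic_opt; the directory and name
# values are illustrative assumptions. The call blocks until the aggregated
# reward stops improving for steps_with_no_improvement finished runs, then
# returns the best result from the data cache.
if __name__ == '__main__':
    best = run_genetic_opt(top_N=10,
                           num_batches=100,
                           batch_size=30,
                           snapshot_dir='./snapshots',   # assumed location
                           root_name='genetic_v2_obj0',  # assumed name
                           obj_num=0,
                           ver='v2')
    print(best)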