def main(argv):
  """Trains a DQN on the `Molecule` environment and saves its hparams.

  Hyperparameters are read from the JSON file named by `FLAGS.hparams` when
  that flag is set, otherwise the `deep_q_networks.get_hparams()` defaults are
  used. `max_steps_per_episode` is always taken from the command-line flag.
  After training, the hparams are written to `<FLAGS.model_dir>/config.json`.

  Args:
    argv: Unused command-line arguments.
  """
  del argv  # Unused.
  if FLAGS.hparams is not None:
    with gfile.Open(FLAGS.hparams, 'r') as f:
      hparams = deep_q_networks.get_hparams(**json.load(f))
  else:
    hparams = deep_q_networks.get_hparams()
  # The flag value wins over anything loaded from the hparams file.
  hparams.override_from_dict(
      {'max_steps_per_episode': FLAGS.max_steps_per_episode})
  # NOTE(review): unlike the other entry points, the ring options
  # (allow_bonds_between_rings / allowed_ring_sizes) from hparams are not
  # forwarded here, so the Molecule constructor defaults apply.
  environment = Molecule(
      atom_types=set(hparams.atom_types),
      init_mol=FLAGS.start_molecule,
      allow_removal=hparams.allow_removal,
      allow_no_modification=hparams.allow_no_modification,
      max_steps=hparams.max_steps_per_episode)
  dqn = deep_q_networks.DeepQNetwork(
      # Each observation is a fingerprint with the number of steps left
      # appended, so the network input is one wider than the fingerprint.
      # This matches every other DeepQNetwork construction in this code.
      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
      q_fn=functools.partial(
          deep_q_networks.multi_layer_model, hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=1.0)
  run_dqn.run_training(
      hparams=hparams,
      environment=environment,
      dqn=dqn,
  )
  core.write_hparams(hparams, os.path.join(FLAGS.model_dir, 'config.json'))
def get_optimized_mols(model_dir, ckpt=80000):
  """Runs a restored DQN from each starting molecule in `all_mols`.

  Loads the hparams from `<model_dir>/config.json`, restores the checkpoint
  `<model_dir>/ckpt-<ckpt>` and, for every starting molecule, takes
  `hparams.max_steps_per_episode` actions chosen by the network with
  epsilon set to 0.0 (presumably a purely greedy policy).

  Args:
    model_dir: String. Model directory containing `config.json` and the
      checkpoint files.
    ckpt: Integer. Step number of the checkpoint to load.

  Returns:
    A list with one entry per molecule in `all_mols`: the value of
    `environment.get_path()` after the rollout (the environment is created
    with `record_path=True`).
  """
  hparams_file = os.path.join(model_dir, 'config.json')
  with gfile.Open(hparams_file, 'r') as f:
    hp_dict = json.load(f)
  hparams = deep_q_networks.get_hparams(**hp_dict)

  dqn = deep_q_networks.DeepQNetwork(
      # Observations are fingerprints with steps_left appended (see below).
      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
      q_fn=functools.partial(
          deep_q_networks.multi_layer_model, hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=0.0)
  # NOTE(review): the default graph is reset after constructing `dqn`; this
  # assumes the constructor creates no ops and `dqn.build()` creates them all.
  tf.reset_default_graph()
  optimized_mol = []
  with tf.Session() as sess:
    dqn.build()
    # The TF1 saver lives in `tf.train`; there is no top-level `tf.Saver`.
    model_saver = tf.train.Saver(max_to_keep=hparams.max_num_checkpoints)
    model_saver.restore(sess, os.path.join(model_dir, 'ckpt-%i' % ckpt))
    # NOTE(review): `all_mols` is not defined in this function; presumably a
    # module-level list of starting molecules — confirm.
    for mol in all_mols:
      logging.info('Eval: %s', mol)
      environment = molecules_mdp.Molecule(
          atom_types=set(hparams.atom_types),
          init_mol=mol,
          allow_removal=hparams.allow_removal,
          allow_no_modification=hparams.allow_no_modification,
          allow_bonds_between_rings=hparams.allow_bonds_between_rings,
          allowed_ring_sizes=set(hparams.allowed_ring_sizes),
          max_steps=hparams.max_steps_per_episode,
          record_path=True)
      environment.initialize()
      # With bootstrap heads, one head is picked at random per molecule.
      if hparams.num_bootstrap_heads:
        head = np.random.randint(hparams.num_bootstrap_heads)
      else:
        head = 0
      for _ in range(hparams.max_steps_per_episode):
        steps_left = (
            hparams.max_steps_per_episode - environment.num_steps_taken)
        valid_actions = list(environment.get_valid_actions())
        # One row per candidate action: its fingerprint plus steps_left.
        observations = np.vstack([
            np.append(deep_q_networks.get_fingerprint(act, hparams),
                      steps_left) for act in valid_actions
        ])
        action = valid_actions[dqn.get_action(
            observations, head=head, update_epsilon=0.0)]
        environment.step(action)
      optimized_mol.append(environment.get_path())
  return optimized_mol
def test_get_fingerprint(self):
  """The 64-bit fingerprint of 'c1ccccc1' (benzene) has a fixed bit pattern."""
  # Set bits at indices 0, 5, 16 and 17; the remaining bits are zero.
  expected_bits = [1, 0, 0, 0, 0, 1] + [0] * 10 + [1, 1] + [0] * 46
  fingerprint = deep_q_networks.get_fingerprint(
      'c1ccccc1', deep_q_networks.get_hparams(fingerprint_length=64))
  self.assertListEqual(fingerprint.tolist(), expected_bits)
def test_get_fingerprint_with_steps_left(self):
  """The steps-left count is appended after the 16 fingerprint bits of 'CC'."""
  steps_remaining = 9
  short_hparams = deep_q_networks.get_hparams(fingerprint_length=16)
  result = deep_q_networks.get_fingerprint_with_steps_left(
      'CC', steps_left=steps_remaining, hparams=short_hparams)
  expected_bits = [
      0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,
  ]
  self.assertTupleEqual(result.shape, (17,))
  self.assertListEqual(result.tolist(),
                       expected_bits + [float(steps_remaining)])
def main(argv):
  """Trains a DQN on LogPRewardWithSimilarityConstraintMolecule.

  The environment receives every molecule listed in 'all_800_mols.json'.
  Hyperparameters come from the JSON file in `FLAGS.hparams` when given,
  otherwise from the defaults; they are written to
  `<FLAGS.model_dir>/config.json` once training finishes.

  Args:
    argv: Unused command-line arguments.
  """
  del argv  # unused.
  overrides = {}
  if FLAGS.hparams is not None:
    with gfile.Open(FLAGS.hparams, 'r') as hparams_fp:
      overrides = json.load(hparams_fp)
  hparams = deep_q_networks.get_hparams(**overrides)

  molecules_path = 'all_800_mols.json'
  with gfile.Open(molecules_path) as mols_fp:
    starting_molecules = json.load(mols_fp)

  env = LogPRewardWithSimilarityConstraintMolecule(
      similarity_constraint=FLAGS.similarity_constraint,
      discount_factor=hparams.discount_factor,
      all_molecules=starting_molecules,
      atom_types=set(hparams.atom_types),
      init_mol=None,
      allow_removal=hparams.allow_removal,
      allow_no_modification=hparams.allow_no_modification,
      allow_bonds_between_rings=hparams.allow_bonds_between_rings,
      allowed_ring_sizes=set(hparams.allowed_ring_sizes),
      max_steps=hparams.max_steps_per_episode)
  # Input width is fingerprint bits plus one steps-left feature.
  q_network = deep_q_networks.DeepQNetwork(
      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
      q_fn=functools.partial(
          deep_q_networks.multi_layer_model, hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=1.0)
  run_dqn.run_training(hparams=hparams, environment=env, dqn=q_network)
  core.write_hparams(hparams, os.path.join(FLAGS.model_dir, 'config.json'))
def main(argv):
  """Trains a DQN from a fixed starting molecule with QED/SAS targets.

  `FLAGS.target_qed` and `FLAGS.target_sas` are added to the hparams, so they
  are also saved to `<FLAGS.model_dir>/config.json` after training.

  Args:
    argv: Unused command-line arguments.
  """
  del argv
  overrides = {}
  if FLAGS.hparams is not None:
    with gfile.Open(FLAGS.hparams, 'r') as hparams_fp:
      overrides = json.load(hparams_fp)
  hparams = deep_q_networks.get_hparams(**overrides)
  for extra_name, extra_value in (('target_qed', FLAGS.target_qed),
                                  ('target_sas', FLAGS.target_sas)):
    hparams.add_hparam(extra_name, extra_value)

  start_smiles = 'CCc1c(C)[nH]c2CCC(CN3CCOCC3)C(=O)c12'
  env = Molecule(
      atom_types=set(hparams.atom_types),
      init_mol=start_smiles,
      allow_removal=hparams.allow_removal,
      allow_no_modification=hparams.allow_no_modification,
      allow_bonds_between_rings=False,
      allowed_ring_sizes={3, 4, 5, 6},
      max_steps=hparams.max_steps_per_episode)
  # Input width is fingerprint bits plus one steps-left feature.
  q_network = deep_q_networks.DeepQNetwork(
      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
      q_fn=functools.partial(
          deep_q_networks.multi_layer_model, hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=1.0)
  run_dqn.run_training(hparams=hparams, environment=env, dqn=q_network)
  core.write_hparams(hparams, os.path.join(FLAGS.model_dir, 'config.json'))
def test_multi_objective_dqn(self):
  """Runs multi-objective training end to end with a tiny configuration."""
  config_path = os.path.join(self.mount_point, 'config.json')
  tiny_config = {
      'replay_buffer_size': 100,
      'num_episodes': 10,
      'batch_size': 10,
      'update_frequency': 1,
      'save_frequency': 1,
      'dense_layers': [32],
      'fingerprint_length': 128,
      'num_bootstrap_heads': 0,
      'prioritized': False,
      'double_q': False,
      'fingerprint_radius': 2,
  }
  core.write_hparams(deep_q_networks.get_hparams(**tiny_config), config_path)
  with flagsaver.flagsaver(model_dir=self.model_dir, hparams=config_path):
    run_dqn.run_dqn(multi_objective=True)
def test_run(self):
  """Runs optimize_qed.main with a tiny bootstrapped, prioritized config."""
  config_path = os.path.join(self.mount_point, 'config.json')
  tiny_config = {
      'replay_buffer_size': 100,
      'num_episodes': 10,
      'batch_size': 10,
      'update_frequency': 1,
      'save_frequency': 1,
      'dense_layers': [32],
      'fingerprint_length': 128,
      'fingerprint_radius': 2,
      'num_bootstrap_heads': 12,
      'prioritized': True,
      'double_q': True,
  }
  core.write_hparams(deep_q_networks.get_hparams(**tiny_config), config_path)
  with flagsaver.flagsaver(model_dir=self.model_dir, hparams=config_path):
    optimize_qed.main(None)
def run_dqn(multi_objective=False):
  """Builds an environment and a Q-network from hparams, then trains them.

  Hyperparameters come from the JSON file in `FLAGS.hparams` when that flag
  is set, otherwise from `deep_q_networks.get_hparams()` defaults. They are
  logged before training and written to `<FLAGS.model_dir>/config.json`
  afterwards.

  Args:
    multi_objective: Boolean. If True, train a MultiObjectiveDeepQNetwork on a
      MultiObjectiveRewardMolecule; otherwise a DeepQNetwork on a
      TargetWeightMolecule.
  """
  overrides = {}
  if FLAGS.hparams is not None:
    with gfile.Open(FLAGS.hparams, 'r') as hparams_fp:
      overrides = json.load(hparams_fp)
  hparams = deep_q_networks.get_hparams(**overrides)

  hparam_lines = [
      '\t%s: %s' % (name, setting)
      for name, setting in sorted(hparams.values().items())
  ]
  logging.info('HParams:\n%s', '\n'.join(hparam_lines))

  # Constructor arguments shared by both environment classes.
  shared_env_args = dict(
      atom_types=set(hparams.atom_types),
      init_mol=FLAGS.start_molecule,
      allow_removal=hparams.allow_removal,
      allow_no_modification=hparams.allow_no_modification,
      max_steps=hparams.max_steps_per_episode)
  # Constructor arguments shared by both Q-network classes; the input width
  # is the fingerprint plus one steps-left feature.
  shared_dqn_args = dict(
      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
      q_fn=functools.partial(
          deep_q_networks.multi_layer_model, hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=1.0)

  # TODO(zzp): merge single objective DQN to multi objective DQN.
  if not multi_objective:
    environment = TargetWeightMolecule(
        target_weight=FLAGS.target_weight,
        allow_bonds_between_rings=hparams.allow_bonds_between_rings,
        allowed_ring_sizes=set(hparams.allowed_ring_sizes),
        **shared_env_args)
    dqn = deep_q_networks.DeepQNetwork(**shared_dqn_args)
  else:
    # The multi-objective setup uses fixed ring options rather than hparams.
    environment = MultiObjectiveRewardMolecule(
        target_molecule=FLAGS.target_molecule,
        allow_bonds_between_rings=False,
        allowed_ring_sizes={3, 4, 5, 6},
        **shared_env_args)
    dqn = deep_q_networks.MultiObjectiveDeepQNetwork(
        objective_weight=np.array([[0.5], [0.5]]), **shared_dqn_args)

  run_training(hparams=hparams, environment=environment, dqn=dqn)
  core.write_hparams(hparams, os.path.join(FLAGS.model_dir, 'config.json'))