def main(argv):
    """Trains a DQN agent on the BARewardMolecule environment.

    Hyperparameters are taken from the JSON file named by `FLAGS.hparams` when
    that flag is set; otherwise the defaults returned by
    `deep_q_networks_parent.get_hparams()` are used. Once training finishes,
    the hyperparameters are saved to `<FLAGS.model_dir>/config_sa.json`.

    Args:
      argv: Command-line arguments; ignored.
    """
    del argv  # Unused.

    if FLAGS.hparams is None:
        hparams = deep_q_networks_parent.get_hparams()
    else:
        with gfile.Open(FLAGS.hparams, 'r') as hparams_file:
            overrides = json.load(hparams_file)
        hparams = deep_q_networks_parent.get_hparams(**overrides)

    environment = BARewardMolecule(
        discount_factor=hparams.discount_factor,
        atom_types=set(hparams.atom_types),
        init_mol=FLAGS.start_molecule,
        allow_removal=hparams.allow_removal,
        allow_no_modification=hparams.allow_no_modification,
        allow_bonds_between_rings=hparams.allow_bonds_between_rings,
        allowed_ring_sizes=set(hparams.allowed_ring_sizes),
        max_steps=hparams.max_steps_per_episode)

    # The network's input width is fingerprint_length + 1.
    q_network_fn = functools.partial(
        deep_q_networks_parent.multi_layer_model, hparams=hparams)
    dqn = deep_q_networks_parent.DeepQNetwork(
        input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
        q_fn=q_network_fn,
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0)

    run_dqn_parent.run_training(
        hparams=hparams, environment=environment, dqn=dqn)

    config_path = os.path.join(FLAGS.model_dir, 'config_sa.json')
    core.write_hparams(hparams, config_path)
def test_run(self):
    """Smoke-tests optimize_qed.main with a tiny bootstrapped, prioritized, double-Q config."""
    overrides = {
        'replay_buffer_size': 100,
        'num_episodes': 10,
        'batch_size': 10,
        'update_frequency': 1,
        'save_frequency': 1,
        'dense_layers': [32],
        'fingerprint_length': 128,
        'fingerprint_radius': 2,
        'num_bootstrap_heads': 12,
        'prioritized': True,
        'double_q': True,
    }
    hparams = deep_q_networks.get_hparams(**overrides)

    # main() reads its hyperparameters from the file named by --hparams.
    config_path = os.path.join(self.mount_point, 'config.json')
    core.write_hparams(hparams, config_path)

    with flagsaver.flagsaver(model_dir=self.model_dir, hparams=config_path):
        optimize_qed.main(None)
def test_multi_objective_dqn(self):
    """Smoke-tests run_dqn.run_dqn in multi-objective mode with a tiny plain-DQN config."""
    overrides = {
        'replay_buffer_size': 100,
        'num_episodes': 10,
        'batch_size': 10,
        'update_frequency': 1,
        'save_frequency': 1,
        'dense_layers': [32],
        'fingerprint_length': 128,
        'num_bootstrap_heads': 0,
        'prioritized': False,
        'double_q': False,
        'fingerprint_radius': 2,
    }
    hparams = deep_q_networks.get_hparams(**overrides)

    # run_dqn() reads its hyperparameters from the file named by --hparams.
    config_path = os.path.join(self.mount_point, 'config.json')
    core.write_hparams(hparams, config_path)

    with flagsaver.flagsaver(model_dir=self.model_dir, hparams=config_path):
        # Positional True selects the multi-objective branch.
        run_dqn.run_dqn(True)
def run_dqn(multi_objective=False):
    """Run the training of Deep Q Network algorithm.

    Hyperparameters are read from the JSON file named by `FLAGS.hparams` when
    that flag is set; otherwise the defaults from
    `deep_q_networks.get_hparams()` are used. The hyperparameters are logged,
    a molecule environment and a DQN are built, `run_training` is invoked, and
    finally the hyperparameters are written to `<FLAGS.model_dir>/config.json`.

    Args:
      multi_objective: Boolean. Whether to run the multiobjective DQN.
        If True, uses `MultiObjectiveRewardMolecule` (with
        `FLAGS.target_molecule`) and `MultiObjectiveDeepQNetwork`. The two
        objective weights are `FLAGS.similarity_weight` and
        `1 - FLAGS.similarity_weight`. If False, uses `TargetWeightMolecule`
        (with `FLAGS.target_weight`) and `DeepQNetwork`.
    """
    if FLAGS.hparams is not None:
        with gfile.Open(FLAGS.hparams, 'r') as f:
            hparams = deep_q_networks.get_hparams(**json.load(f))
    else:
        hparams = deep_q_networks.get_hparams()
    logging.info(
        'HParams:\n%s',
        '\n'.join('\t%s: %s' % (key, value)
                  for key, value in sorted(hparams.values().items())))

    # Environment settings shared by both branches. Ring constraints come from
    # hparams in both cases, so overrides in the hparams file are honored and
    # the config.json written below matches what was actually used.
    # NOTE(review): the multi-objective branch previously hard-coded
    # allow_bonds_between_rings=False and allowed_ring_sizes={3, 4, 5, 6}.
    # Confirm the hparams defaults match if exact reproduction of older runs
    # matters.
    environment_kwargs = dict(
        atom_types=set(hparams.atom_types),
        init_mol=FLAGS.start_molecule,
        allow_removal=hparams.allow_removal,
        allow_no_modification=hparams.allow_no_modification,
        allow_bonds_between_rings=hparams.allow_bonds_between_rings,
        allowed_ring_sizes=set(hparams.allowed_ring_sizes),
        max_steps=hparams.max_steps_per_episode)

    # DQN settings shared by both branches. The network's input width is
    # fingerprint_length + 1.
    dqn_kwargs = dict(
        input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
        q_fn=functools.partial(
            deep_q_networks.multi_layer_model, hparams=hparams),
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0)

    # TODO(zzp): merge single objective DQN to multi objective DQN.
    if multi_objective:
        environment = MultiObjectiveRewardMolecule(
            target_molecule=FLAGS.target_molecule, **environment_kwargs)
        dqn = deep_q_networks.MultiObjectiveDeepQNetwork(
            objective_weight=np.array([[FLAGS.similarity_weight],
                                       [1 - FLAGS.similarity_weight]]),
            **dqn_kwargs)
    else:
        environment = TargetWeightMolecule(
            target_weight=FLAGS.target_weight, **environment_kwargs)
        dqn = deep_q_networks.DeepQNetwork(**dqn_kwargs)

    run_training(hparams=hparams, environment=environment, dqn=dqn)
    core.write_hparams(hparams, os.path.join(FLAGS.model_dir, 'config.json'))