def main(argv):
    """Train a DeepQNetwork on a `Molecule` environment.

    Hyperparameters are loaded from the JSON file named by `FLAGS.hparams`
    when that flag is set; otherwise `deep_q_networks.get_hparams()` is
    called with no overrides. After `run_dqn.run_training` returns, the
    hyperparameters are written to `config.json` under `FLAGS.model_dir`.

    Args:
      argv: Command-line arguments; discarded without being read.
    """
    del argv

    hparams_path = FLAGS.hparams
    if hparams_path is None:
        hparams = deep_q_networks.get_hparams()
    else:
        with open(hparams_path, 'r') as hparams_file:
            hparams = deep_q_networks.get_hparams(**json.load(hparams_file))

    # Environment settings come from the hyperparameters, except the
    # starting molecule, which is taken from the command-line flags.
    environment_options = {
        'atom_types': set(hparams.atom_types),
        'init_mol': FLAGS.start_molecule,
        'allow_removal': hparams.allow_removal,
        'allow_no_modification': hparams.allow_no_modification,
        'allow_bonds_between_rings': hparams.allow_bonds_between_rings,
        'allowed_ring_sizes': set(hparams.allowed_ring_sizes),
        'max_steps': hparams.max_steps_per_episode,
    }
    environment = Molecule(**environment_options)

    network_options = {
        'input_shape': (hparams.batch_size, hparams.fingerprint_length + 1),
        'q_fn': functools.partial(
            deep_q_networks.multi_layer_model, hparams=hparams),
        'optimizer': hparams.optimizer,
        'grad_clipping': hparams.grad_clipping,
        'num_bootstrap_heads': hparams.num_bootstrap_heads,
        'gamma': hparams.gamma,
        'epsilon': 1.0,
    }
    dqn = deep_q_networks.DeepQNetwork(**network_options)

    run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

    config_path = os.path.join(FLAGS.model_dir, 'config.json')
    core.write_hparams(hparams, config_path)
def main(argv):
    """Train a DeepQNetwork on a similarity-constrained logP environment.

    Hyperparameters are loaded from the JSON file named by `FLAGS.hparams`
    when that flag is set; otherwise `deep_q_networks.get_hparams()` is
    called with no overrides. The molecule collection is read from
    `all_800_mols.json` and handed to
    `LogPRewardWithSimilarityConstraintMolecule`. After
    `run_dqn.run_training` returns, the hyperparameters are written to
    `config.json` under `FLAGS.model_dir`.

    Args:
      argv: Command-line arguments; discarded without being read.
    """
    del argv  # unused.

    hparams_path = FLAGS.hparams
    if hparams_path is None:
        hparams = deep_q_networks.get_hparams()
    else:
        with gfile.Open(hparams_path, 'r') as hparams_file:
            hparams = deep_q_networks.get_hparams(**json.load(hparams_file))

    molecules_path = 'all_800_mols.json'
    with gfile.Open(molecules_path) as molecules_file:
        all_molecules = json.load(molecules_file)

    # No fixed starting molecule is given (init_mol=None); the loaded
    # molecule collection is passed to the environment instead.
    environment_options = {
        'similarity_constraint': FLAGS.similarity_constraint,
        'discount_factor': hparams.discount_factor,
        'all_molecules': all_molecules,
        'atom_types': set(hparams.atom_types),
        'init_mol': None,
        'allow_removal': hparams.allow_removal,
        'allow_no_modification': hparams.allow_no_modification,
        'allow_bonds_between_rings': hparams.allow_bonds_between_rings,
        'allowed_ring_sizes': set(hparams.allowed_ring_sizes),
        'max_steps': hparams.max_steps_per_episode,
    }
    environment = LogPRewardWithSimilarityConstraintMolecule(
        **environment_options)

    network_options = {
        'input_shape': (hparams.batch_size, hparams.fingerprint_length + 1),
        'q_fn': functools.partial(
            deep_q_networks.multi_layer_model, hparams=hparams),
        'optimizer': hparams.optimizer,
        'grad_clipping': hparams.grad_clipping,
        'num_bootstrap_heads': hparams.num_bootstrap_heads,
        'gamma': hparams.gamma,
        'epsilon': 1.0,
    }
    dqn = deep_q_networks.DeepQNetwork(**network_options)

    run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

    config_path = os.path.join(FLAGS.model_dir, 'config.json')
    core.write_hparams(hparams, config_path)
def run_dqn(multi_objective=False):
    """Train a Deep Q Network and save the hyperparameters that were used.

    Hyperparameters are loaded from the JSON file named by `FLAGS.hparams`
    when that flag is set; otherwise `deep_q_networks.get_hparams()` is
    called with no overrides. They are logged, used to build the environment
    and the network, and finally written to `config.json` under
    `FLAGS.model_dir` once `run_training` has returned.

    Args:
      multi_objective: Boolean. If True, a `MultiObjectiveDeepQNetwork` is
        trained on a `MultiObjectiveRewardMolecule` environment; otherwise a
        `DeepQNetwork` is trained on a `TargetWeightMolecule` environment.
    """
    hparams_path = FLAGS.hparams
    if hparams_path is None:
        hparams = deep_q_networks.get_hparams()
    else:
        with gfile.Open(hparams_path, 'r') as hparams_file:
            hparams = deep_q_networks.get_hparams(**json.load(hparams_file))

    # Log one "name: value" line per hyperparameter, sorted by name.
    hparam_lines = []
    for key, value in sorted(hparams.values().items()):
        hparam_lines.append('\t%s: %s' % (key, value))
    logging.info('HParams:\n%s', '\n'.join(hparam_lines))

    # TODO(zzp): merge single objective DQN to multi objective DQN.
    if multi_objective:
        # This branch fixes the ring settings itself rather than reading
        # them from the hyperparameters.
        environment = MultiObjectiveRewardMolecule(
            target_molecule=FLAGS.target_molecule,
            atom_types=set(hparams.atom_types),
            init_mol=FLAGS.start_molecule,
            allow_removal=hparams.allow_removal,
            allow_no_modification=hparams.allow_no_modification,
            allow_bonds_between_rings=False,
            allowed_ring_sizes={3, 4, 5, 6},
            max_steps=hparams.max_steps_per_episode)
        network_cls = deep_q_networks.MultiObjectiveDeepQNetwork
        similarity_weight = FLAGS.similarity_weight
        dqn_kwargs = {
            'objective_weight': np.array([[similarity_weight],
                                          [1 - FLAGS.similarity_weight]]),
        }
    else:
        environment = TargetWeightMolecule(
            target_weight=FLAGS.target_weight,
            atom_types=set(hparams.atom_types),
            init_mol=FLAGS.start_molecule,
            allow_removal=hparams.allow_removal,
            allow_no_modification=hparams.allow_no_modification,
            allow_bonds_between_rings=hparams.allow_bonds_between_rings,
            allowed_ring_sizes=set(hparams.allowed_ring_sizes),
            max_steps=hparams.max_steps_per_episode)
        network_cls = deep_q_networks.DeepQNetwork
        dqn_kwargs = {}

    # Both network variants share these constructor arguments.
    dqn_kwargs.update(
        input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
        q_fn=functools.partial(
            deep_q_networks.multi_layer_model, hparams=hparams),
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0)
    dqn = network_cls(**dqn_kwargs)

    run_training(hparams=hparams, environment=environment, dqn=dqn)

    config_path = os.path.join(FLAGS.model_dir, 'config.json')
    core.write_hparams(hparams, config_path)