def main(argv):
    """Trains a DQN agent that edits molecules from a flag-given start molecule.

    Hyperparameters are read from the JSON file named by ``FLAGS.hparams``
    when that flag is set, and otherwise come from
    ``deep_q_networks.get_hparams()`` defaults. ``max_steps_per_episode`` is
    then overridden from the command line. After training, the final
    hyperparameters are written to ``<FLAGS.model_dir>/config.json``.

    Args:
      argv: Positional command-line arguments; ignored.
    """
    del argv  # Unused.

    # JSON overrides when a file is given; otherwise get_hparams() defaults.
    overrides = {}
    if FLAGS.hparams is not None:
        with gfile.Open(FLAGS.hparams, 'r') as hparams_file:
            overrides = json.load(hparams_file)
    hparams = deep_q_networks.get_hparams(**overrides)

    # Applied after loading, so the command-line flag replaces any value
    # that came from the file or the defaults.
    hparams.override_from_dict(
        {'max_steps_per_episode': FLAGS.max_steps_per_episode})

    environment = Molecule(
        atom_types=set(hparams.atom_types),
        init_mol=FLAGS.start_molecule,
        allow_removal=hparams.allow_removal,
        allow_no_modification=hparams.allow_no_modification,
        max_steps=hparams.max_steps_per_episode)

    q_network = functools.partial(deep_q_networks.multi_layer_model,
                                  hparams=hparams)
    # NOTE(review): the other entry points in this file size the input as
    # fingerprint_length + 1; confirm which width run_training feeds here.
    dqn = deep_q_networks.DeepQNetwork(
        input_shape=(hparams.batch_size, hparams.fingerprint_length),
        q_fn=q_network,
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0)

    run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

    config_path = os.path.join(FLAGS.model_dir, 'config.json')
    core.write_hparams(hparams, config_path)
# Example #2
    def test_multi_objective_dqn(self):
        """Smoke-tests the multi-objective DQN path on a tiny configuration."""
        tiny_config = dict(
            replay_buffer_size=100,
            num_episodes=10,
            batch_size=10,
            update_frequency=1,
            save_frequency=1,
            dense_layers=[32],
            fingerprint_length=128,
            num_bootstrap_heads=0,
            prioritized=False,
            double_q=False,
            fingerprint_radius=2,
        )
        hparams = deep_q_networks.get_hparams(**tiny_config)
        config_path = os.path.join(self.mount_point, 'config.json')
        core.write_hparams(hparams, config_path)

        # Override the flags only for the duration of the run so the trainer
        # loads the config written above.
        with flagsaver.flagsaver(model_dir=self.model_dir,
                                 hparams=config_path):
            run_dqn.run_dqn(multi_objective=True)
    def test_run(self):
        """Smoke-tests optimize_qed.main with bootstrap heads, PER and double Q."""
        tiny_config = dict(
            replay_buffer_size=100,
            num_episodes=10,
            batch_size=10,
            update_frequency=1,
            save_frequency=1,
            dense_layers=[32],
            fingerprint_length=128,
            fingerprint_radius=2,
            num_bootstrap_heads=12,
            prioritized=True,
            double_q=True,
        )
        hparams = deep_q_networks.get_hparams(**tiny_config)
        config_path = os.path.join(self.mount_point, 'config.json')
        core.write_hparams(hparams, config_path)

        # Override the flags only for the duration of the run so main()
        # loads the config written above.
        with flagsaver.flagsaver(model_dir=self.model_dir,
                                 hparams=config_path):
            optimize_qed.main(None)
# Example #4
def main(argv):
    """Trains a DQN for logP improvement under a similarity constraint.

    Hyperparameters are read from the JSON file named by ``FLAGS.hparams``
    when set, otherwise from ``deep_q_networks.get_hparams()`` defaults. The
    candidate molecules are loaded from ``all_800_mols.json``, a path relative
    to the working directory. After training, the hyperparameters are written
    to ``<FLAGS.model_dir>/config.json``.

    Args:
      argv: Positional command-line arguments; ignored.
    """
    del argv  # Unused.

    overrides = {}
    if FLAGS.hparams is not None:
        with gfile.Open(FLAGS.hparams, 'r') as hparams_file:
            overrides = json.load(hparams_file)
    hparams = deep_q_networks.get_hparams(**overrides)

    molecules_path = 'all_800_mols.json'
    with gfile.Open(molecules_path) as molecules_file:
        all_molecules = json.load(molecules_file)

    # No fixed start molecule is given (init_mol=None); presumably the
    # environment draws starting points from all_molecules. TODO confirm.
    environment = LogPRewardWithSimilarityConstraintMolecule(
        similarity_constraint=FLAGS.similarity_constraint,
        discount_factor=hparams.discount_factor,
        all_molecules=all_molecules,
        atom_types=set(hparams.atom_types),
        init_mol=None,
        allow_removal=hparams.allow_removal,
        allow_no_modification=hparams.allow_no_modification,
        allow_bonds_between_rings=hparams.allow_bonds_between_rings,
        allowed_ring_sizes=set(hparams.allowed_ring_sizes),
        max_steps=hparams.max_steps_per_episode)

    q_network = functools.partial(deep_q_networks.multi_layer_model,
                                  hparams=hparams)
    # One column beyond the fingerprint; presumably the remaining-step count
    # appended by run_training. TODO confirm.
    dqn = deep_q_networks.DeepQNetwork(
        input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
        q_fn=q_network,
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0)

    run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

    config_path = os.path.join(FLAGS.model_dir, 'config.json')
    core.write_hparams(hparams, config_path)
# Example #5
def main(argv):
    """Trains a target-weight DQN using the noise-aware network variants.

    Hyperparameters are read from the JSON file named by ``FLAGS.hparams``
    when set, otherwise from ``run_dqn.get_hparams()`` defaults. When
    ``FLAGS.error_type`` is ``'l2'`` (case-insensitive), ``DeepQNetworkL2`` is
    used; otherwise ``DeepQNetwork`` from ``deep_q_networks_noise`` is used.
    After training, ``noise_std`` and ``error_type`` are added to the
    hyperparameters, which are then written to
    ``<FLAGS.model_dir>/config.json``.

    Args:
      argv: Positional command-line arguments; ignored.
    """
    del argv  # Unused.

    overrides = {}
    if FLAGS.hparams is not None:
        with gfile.Open(FLAGS.hparams, 'r') as hparams_file:
            overrides = json.load(hparams_file)
    hparams = run_dqn.get_hparams(**overrides)

    environment = Molecule(
        target_weight=FLAGS.target_weight,
        atom_types=set(hparams.atom_types),
        init_mol=FLAGS.start_molecule,
        allow_removal=hparams.allow_removal,
        allow_no_modification=hparams.allow_no_modification,
        max_steps=hparams.max_steps_per_episode)

    network_cls = (deep_q_networks_noise.DeepQNetworkL2
                   if FLAGS.error_type.lower() == 'l2'
                   else deep_q_networks_noise.DeepQNetwork)
    q_network = functools.partial(deep_q_networks_noise.multi_layer_model,
                                  hparams=hparams)
    dqn = network_cls(
        input_shape=(hparams.batch_size, hparams.fingerprint_length),
        q_fn=q_network,
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0)

    run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

    # Added only after training, so run_training never sees these two
    # entries in hparams; they exist solely for the saved config.
    # NOTE(review): add_hparam presumably rejects names that already exist in
    # a loaded config — confirm before re-using a saved config.json as input.
    for name, value in (('noise_std', FLAGS.noise_std),
                        ('error_type', FLAGS.error_type)):
        hparams.add_hparam(name, value)
    config_path = os.path.join(FLAGS.model_dir, 'config.json')
    core.write_hparams(hparams, config_path)
# Example #6
def main(argv):
  """Trains a DQN from a fixed starting scaffold with QED/SAS targets.

  Hyperparameters are read from the JSON file named by ``FLAGS.hparams`` if
  set, otherwise from ``deep_q_networks.get_hparams()`` defaults.
  ``target_qed`` and ``target_sas`` are then added from the command-line
  flags. After training, the hyperparameters are written to
  ``<FLAGS.model_dir>/config.json``.

  Args:
    argv: Positional command-line arguments; ignored.
  """
  del argv  # Unused.

  overrides = {}
  if FLAGS.hparams is not None:
    with gfile.Open(FLAGS.hparams, 'r') as hparams_file:
      overrides = json.load(hparams_file)
  hparams = deep_q_networks.get_hparams(**overrides)

  # Added before training, so they are part of the hparams handed to
  # run_training and of the config written at the end.
  for name, value in (('target_qed', FLAGS.target_qed),
                      ('target_sas', FLAGS.target_sas)):
    hparams.add_hparam(name, value)

  # Ring rules and the start molecule are fixed here rather than read from
  # hparams.
  environment = Molecule(
      atom_types=set(hparams.atom_types),
      init_mol='CCc1c(C)[nH]c2CCC(CN3CCOCC3)C(=O)c12',
      allow_removal=hparams.allow_removal,
      allow_no_modification=hparams.allow_no_modification,
      allow_bonds_between_rings=False,
      allowed_ring_sizes={3, 4, 5, 6},
      max_steps=hparams.max_steps_per_episode)

  q_network = functools.partial(
      deep_q_networks.multi_layer_model, hparams=hparams)
  dqn = deep_q_networks.DeepQNetwork(
      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
      q_fn=q_network,
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=1.0)

  run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

  config_path = os.path.join(FLAGS.model_dir, 'config.json')
  core.write_hparams(hparams, config_path)
def run_dqn(multi_objective=False, objective_weight=None):
    """Run the training of Deep Q Network algorithm.

    Hyperparameters are read from the JSON file named by ``FLAGS.hparams``
    when set, otherwise from ``deep_q_networks.get_hparams()`` defaults, and
    are logged at INFO level. After training they are written to
    ``<FLAGS.model_dir>/config.json``.

    Args:
      multi_objective: Boolean. Whether to run the multiobjective DQN
        (``MultiObjectiveRewardMolecule`` towards ``FLAGS.target_molecule``)
        instead of the single-objective ``TargetWeightMolecule`` DQN.
      objective_weight: Optional objective weights for the multiobjective
        DQN, converted with ``np.array``. They should follow the column layout
        of the default ``[[0.5], [0.5]]``. Ignored when ``multi_objective`` is
        False. Defaults to ``None``, meaning ``[[0.5], [0.5]]``.
    """
    if FLAGS.hparams is not None:
        with gfile.Open(FLAGS.hparams, 'r') as f:
            hparams = deep_q_networks.get_hparams(**json.load(f))
    else:
        hparams = deep_q_networks.get_hparams()
    logging.info(
        'HParams:\n%s', '\n'.join(
            '\t%s: %s' % (key, value)
            for key, value in sorted(hparams.values().items())))

    # Constructor arguments shared by both network variants. The input is one
    # column wider than the fingerprint (presumably the remaining-step count
    # appended during training — TODO confirm).
    network_kwargs = dict(
        input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
        q_fn=functools.partial(deep_q_networks.multi_layer_model,
                               hparams=hparams),
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0)

    # TODO(zzp): merge single objective DQN to multi objective DQN.
    if multi_objective:
        if objective_weight is None:
            objective_weight = [[0.5], [0.5]]
        # Unlike the single-objective branch, ring rules are fixed here
        # rather than read from hparams.
        environment = MultiObjectiveRewardMolecule(
            target_molecule=FLAGS.target_molecule,
            atom_types=set(hparams.atom_types),
            init_mol=FLAGS.start_molecule,
            allow_removal=hparams.allow_removal,
            allow_no_modification=hparams.allow_no_modification,
            allow_bonds_between_rings=False,
            allowed_ring_sizes={3, 4, 5, 6},
            max_steps=hparams.max_steps_per_episode)
        dqn = deep_q_networks.MultiObjectiveDeepQNetwork(
            objective_weight=np.array(objective_weight), **network_kwargs)
    else:
        environment = TargetWeightMolecule(
            target_weight=FLAGS.target_weight,
            atom_types=set(hparams.atom_types),
            init_mol=FLAGS.start_molecule,
            allow_removal=hparams.allow_removal,
            allow_no_modification=hparams.allow_no_modification,
            allow_bonds_between_rings=hparams.allow_bonds_between_rings,
            allowed_ring_sizes=set(hparams.allowed_ring_sizes),
            max_steps=hparams.max_steps_per_episode)
        dqn = deep_q_networks.DeepQNetwork(**network_kwargs)

    run_training(
        hparams=hparams,
        environment=environment,
        dqn=dqn,
    )

    core.write_hparams(hparams, os.path.join(FLAGS.model_dir, 'config.json'))