def main(argv):
    """Train a DQN agent on a `Molecule` environment configured by flags.

    Hyperparameters are read from the JSON file named by `FLAGS.hparams` when
    that flag is set; otherwise `deep_q_networks.get_hparams()` is called with
    no arguments. `max_steps_per_episode` is then overridden from the flag of
    the same name. Once `run_dqn.run_training` returns, the hyperparameters
    are written to `config.json` inside `FLAGS.model_dir`.

    Args:
      argv: Positional command-line arguments; unused.
    """
    del argv  # All configuration is taken from FLAGS.

    if FLAGS.hparams is None:
        hparams = deep_q_networks.get_hparams()
    else:
        with gfile.Open(FLAGS.hparams, 'r') as hparams_file:
            hparams = deep_q_networks.get_hparams(**json.load(hparams_file))

    # The command-line episode length wins over whatever the file specified.
    step_override = {'max_steps_per_episode': FLAGS.max_steps_per_episode}
    hparams.override_from_dict(step_override)

    environment = Molecule(
        atom_types=set(hparams.atom_types),
        init_mol=FLAGS.start_molecule,
        allow_removal=hparams.allow_removal,
        allow_no_modification=hparams.allow_no_modification,
        max_steps=hparams.max_steps_per_episode,
    )

    observation_shape = (hparams.batch_size, hparams.fingerprint_length)
    q_function = functools.partial(
        deep_q_networks.multi_layer_model, hparams=hparams)
    dqn = deep_q_networks.DeepQNetwork(
        input_shape=observation_shape,
        q_fn=q_function,
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0,
    )

    run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

    config_path = os.path.join(FLAGS.model_dir, 'config.json')
    core.write_hparams(hparams, config_path)
示例#2
0
def get_optimized_mols(model_dir, ckpt=80000):
    """Roll out a trained DQN greedily from every molecule in `all_mols`.

    The hyperparameters are read from `config.json` in `model_dir`, the
    network weights are restored from the `ckpt-<ckpt>` checkpoint in the same
    directory, and one episode of `hparams.max_steps_per_episode` steps is run
    per starting molecule with epsilon fixed at 0.0.

    Args:
      model_dir: String. Directory holding `config.json` and the checkpoints.
      ckpt: Integer. Step number of the checkpoint to load.

    Returns:
      A list with one entry per element of `all_mols` (a module-level name not
      defined in this function): the value of `environment.get_path()` at the
      end of that molecule's episode.
    """
    hparams_file = os.path.join(model_dir, 'config.json')
    with gfile.Open(hparams_file, 'r') as f:
        hp_dict = json.load(f)
        hparams = deep_q_networks.get_hparams(**hp_dict)

    # The "+ 1" makes room for the steps-left value appended to each
    # fingerprint when the observations are built below.
    dqn = deep_q_networks.DeepQNetwork(
        input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
        q_fn=functools.partial(deep_q_networks.multi_layer_model,
                               hparams=hparams),
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=0.0)

    tf.reset_default_graph()
    optimized_mol = []
    with tf.Session() as sess:
        dqn.build()
        # The TF1 saver lives in tf.train; there is no top-level tf.Saver.
        model_saver = tf.train.Saver(max_to_keep=hparams.max_num_checkpoints)
        model_saver.restore(sess, os.path.join(model_dir, 'ckpt-%i' % ckpt))
        for mol in all_mols:
            logging.info('Eval: %s', mol)
            environment = molecules_mdp.Molecule(
                atom_types=set(hparams.atom_types),
                init_mol=mol,
                allow_removal=hparams.allow_removal,
                allow_no_modification=hparams.allow_no_modification,
                allow_bonds_between_rings=hparams.allow_bonds_between_rings,
                allowed_ring_sizes=set(hparams.allowed_ring_sizes),
                max_steps=hparams.max_steps_per_episode,
                record_path=True)
            environment.initialize()
            # Pick one bootstrap head per episode when heads are enabled.
            if hparams.num_bootstrap_heads:
                head = np.random.randint(hparams.num_bootstrap_heads)
            else:
                head = 0
            for _ in range(hparams.max_steps_per_episode):
                steps_left = (hparams.max_steps_per_episode -
                              environment.num_steps_taken)
                valid_actions = list(environment.get_valid_actions())
                # One observation row per candidate action: its fingerprint
                # followed by the number of steps remaining.
                observations = np.vstack([
                    np.append(deep_q_networks.get_fingerprint(act, hparams),
                              steps_left) for act in valid_actions
                ])
                action = valid_actions[dqn.get_action(observations,
                                                      head=head,
                                                      update_epsilon=0.0)]
                environment.step(action)
            optimized_mol.append(environment.get_path())
    return optimized_mol
 def test_get_fingerprint(self):
     """Checks the length-64 fingerprint returned for 'c1ccccc1'."""
     hparams = deep_q_networks.get_hparams(fingerprint_length=64)
     # Only these four positions are expected to be set; the rest are zero.
     expected = [0] * 64
     for set_index in (0, 5, 16, 17):
         expected[set_index] = 1
     actual = deep_q_networks.get_fingerprint('c1ccccc1', hparams).tolist()
     self.assertListEqual(actual, expected)
 def test_get_fingerprint_with_steps_left(self):
     """Checks that 'CC' with steps_left=9 yields 16 bits plus a final 9.0."""
     hparams = deep_q_networks.get_hparams(fingerprint_length=16)
     observation = deep_q_networks.get_fingerprint_with_steps_left(
         'CC', steps_left=9, hparams=hparams)
     # Sixteen fingerprint entries, of which positions 1 and 11 are set.
     expected_bits = [0.0] * 16
     expected_bits[1] = 1.0
     expected_bits[11] = 1.0
     self.assertTupleEqual(observation.shape, (17, ))
     # The steps-left value occupies the last slot.
     self.assertListEqual(observation.tolist(), expected_bits + [9.0])
示例#5
0
def main(argv):
    """Train a DQN on `LogPRewardWithSimilarityConstraintMolecule`.

    Hyperparameters are read from the JSON file named by `FLAGS.hparams` when
    that flag is set; otherwise `deep_q_networks.get_hparams()` is called with
    no arguments. The molecule collection is loaded from `all_800_mols.json`
    and handed to the environment, which is created with `init_mol=None`.
    Once `run_dqn.run_training` returns, the hyperparameters are written to
    `config.json` inside `FLAGS.model_dir`.

    Args:
      argv: Positional command-line arguments; unused.
    """
    del argv  # unused.

    if FLAGS.hparams is None:
        hparams = deep_q_networks.get_hparams()
    else:
        with gfile.Open(FLAGS.hparams, 'r') as hparams_file:
            hparams = deep_q_networks.get_hparams(**json.load(hparams_file))

    with gfile.Open('all_800_mols.json') as molecules_file:
        all_molecules = json.load(molecules_file)

    environment = LogPRewardWithSimilarityConstraintMolecule(
        similarity_constraint=FLAGS.similarity_constraint,
        discount_factor=hparams.discount_factor,
        all_molecules=all_molecules,
        atom_types=set(hparams.atom_types),
        init_mol=None,
        allow_removal=hparams.allow_removal,
        allow_no_modification=hparams.allow_no_modification,
        allow_bonds_between_rings=hparams.allow_bonds_between_rings,
        allowed_ring_sizes=set(hparams.allowed_ring_sizes),
        max_steps=hparams.max_steps_per_episode,
    )

    # One extra input column beyond the fingerprint length.
    observation_shape = (hparams.batch_size, hparams.fingerprint_length + 1)
    q_function = functools.partial(
        deep_q_networks.multi_layer_model, hparams=hparams)
    dqn = deep_q_networks.DeepQNetwork(
        input_shape=observation_shape,
        q_fn=q_function,
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=1.0,
    )

    run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

    config_path = os.path.join(FLAGS.model_dir, 'config.json')
    core.write_hparams(hparams, config_path)
示例#6
0
def main(argv):
  """Train a DQN on a `Molecule` environment with target_qed/target_sas set.

  Hyperparameters are read from the JSON file named by `FLAGS.hparams` when
  that flag is set; otherwise `deep_q_networks.get_hparams()` is called with
  no arguments. `target_qed` and `target_sas` are then added from the flags of
  the same names. The environment starts from a hard-coded molecule string,
  disallows bonds between rings and permits ring sizes 3 to 6. Once
  `run_dqn.run_training` returns, the hyperparameters are written to
  `config.json` inside `FLAGS.model_dir`.

  Args:
    argv: Positional command-line arguments; unused.
  """
  del argv  # All configuration is taken from FLAGS.

  if FLAGS.hparams is None:
    hparams = deep_q_networks.get_hparams()
  else:
    with gfile.Open(FLAGS.hparams, 'r') as hparams_file:
      hparams = deep_q_networks.get_hparams(**json.load(hparams_file))

  # Register the targets as hyperparameters so they are saved with the rest.
  hparams.add_hparam('target_qed', FLAGS.target_qed)
  hparams.add_hparam('target_sas', FLAGS.target_sas)

  start_molecule = 'CCc1c(C)[nH]c2CCC(CN3CCOCC3)C(=O)c12'
  environment = Molecule(
      atom_types=set(hparams.atom_types),
      init_mol=start_molecule,
      allow_removal=hparams.allow_removal,
      allow_no_modification=hparams.allow_no_modification,
      allow_bonds_between_rings=False,
      allowed_ring_sizes={3, 4, 5, 6},
      max_steps=hparams.max_steps_per_episode,
  )

  # One extra input column beyond the fingerprint length.
  observation_shape = (hparams.batch_size, hparams.fingerprint_length + 1)
  q_function = functools.partial(
      deep_q_networks.multi_layer_model, hparams=hparams)
  dqn = deep_q_networks.DeepQNetwork(
      input_shape=observation_shape,
      q_fn=q_function,
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=1.0,
  )

  run_dqn.run_training(hparams=hparams, environment=environment, dqn=dqn)

  config_path = os.path.join(FLAGS.model_dir, 'config.json')
  core.write_hparams(hparams, config_path)
示例#7
0
    def test_multi_objective_dqn(self):
        """Runs `run_dqn.run_dqn(True)` end to end with a small config.

        The hyperparameters are written to `config.json` under
        `self.mount_point` and passed in through the `hparams` flag.
        """
        overrides = {
            'replay_buffer_size': 100,
            'num_episodes': 10,
            'batch_size': 10,
            'update_frequency': 1,
            'save_frequency': 1,
            'dense_layers': [32],
            'fingerprint_length': 128,
            'num_bootstrap_heads': 0,
            'prioritized': False,
            'double_q': False,
            'fingerprint_radius': 2,
        }
        hparams = deep_q_networks.get_hparams(**overrides)
        config_path = os.path.join(self.mount_point, 'config.json')
        core.write_hparams(hparams, config_path)

        # Flag values are only overridden for the duration of the run.
        flag_overrides = flagsaver.flagsaver(
            model_dir=self.model_dir, hparams=config_path)
        with flag_overrides:
            run_dqn.run_dqn(True)
    def test_run(self):
        """Runs `optimize_qed.main` end to end with a small config.

        Uses 12 bootstrap heads with prioritized replay and double Q enabled.
        The hyperparameters are written to `config.json` under
        `self.mount_point` and passed in through the `hparams` flag.
        """
        overrides = {
            'replay_buffer_size': 100,
            'num_episodes': 10,
            'batch_size': 10,
            'update_frequency': 1,
            'save_frequency': 1,
            'dense_layers': [32],
            'fingerprint_length': 128,
            'fingerprint_radius': 2,
            'num_bootstrap_heads': 12,
            'prioritized': True,
            'double_q': True,
        }
        hparams = deep_q_networks.get_hparams(**overrides)
        config_path = os.path.join(self.mount_point, 'config.json')
        core.write_hparams(hparams, config_path)

        # Flag values are only overridden for the duration of the run.
        flag_overrides = flagsaver.flagsaver(
            model_dir=self.model_dir, hparams=config_path)
        with flag_overrides:
            optimize_qed.main(None)
def run_dqn(multi_objective=False):
    """Build an environment and a Deep Q Network, train, then save hparams.

    Hyperparameters are read from the JSON file named by `FLAGS.hparams` when
    that flag is set; otherwise `deep_q_networks.get_hparams()` is called with
    no arguments. They are logged, used to construct the environment and the
    network, and written to `config.json` inside `FLAGS.model_dir` once
    `run_training` returns.

    Args:
      multi_objective: Boolean. When true, pairs
        `MultiObjectiveRewardMolecule` with `MultiObjectiveDeepQNetwork`;
        otherwise pairs `TargetWeightMolecule` with `DeepQNetwork`.
    """
    if FLAGS.hparams is None:
        hparams = deep_q_networks.get_hparams()
    else:
        with gfile.Open(FLAGS.hparams, 'r') as hparams_file:
            hparams = deep_q_networks.get_hparams(**json.load(hparams_file))

    hparam_lines = []
    for key, value in sorted(hparams.values().items()):
        hparam_lines.append('\t%s: %s' % (key, value))
    logging.info('HParams:\n%s', '\n'.join(hparam_lines))

    # TODO(zzp): merge single objective DQN to multi objective DQN.
    if multi_objective:
        # Ring settings are fixed here rather than read from hparams.
        environment = MultiObjectiveRewardMolecule(
            target_molecule=FLAGS.target_molecule,
            atom_types=set(hparams.atom_types),
            init_mol=FLAGS.start_molecule,
            allow_removal=hparams.allow_removal,
            allow_no_modification=hparams.allow_no_modification,
            allow_bonds_between_rings=False,
            allowed_ring_sizes={3, 4, 5, 6},
            max_steps=hparams.max_steps_per_episode,
        )
    else:
        environment = TargetWeightMolecule(
            target_weight=FLAGS.target_weight,
            atom_types=set(hparams.atom_types),
            init_mol=FLAGS.start_molecule,
            allow_removal=hparams.allow_removal,
            allow_no_modification=hparams.allow_no_modification,
            allow_bonds_between_rings=hparams.allow_bonds_between_rings,
            allowed_ring_sizes=set(hparams.allowed_ring_sizes),
            max_steps=hparams.max_steps_per_episode,
        )

    # Both network types take the same base arguments.
    network_args = {
        'input_shape': (hparams.batch_size, hparams.fingerprint_length + 1),
        'q_fn': functools.partial(deep_q_networks.multi_layer_model,
                                  hparams=hparams),
        'optimizer': hparams.optimizer,
        'grad_clipping': hparams.grad_clipping,
        'num_bootstrap_heads': hparams.num_bootstrap_heads,
        'gamma': hparams.gamma,
        'epsilon': 1.0,
    }
    if multi_objective:
        # The two objectives are weighted equally.
        dqn = deep_q_networks.MultiObjectiveDeepQNetwork(
            objective_weight=np.array([[0.5], [0.5]]), **network_args)
    else:
        dqn = deep_q_networks.DeepQNetwork(**network_args)

    run_training(hparams=hparams, environment=environment, dqn=dqn)

    config_path = os.path.join(FLAGS.model_dir, 'config.json')
    core.write_hparams(hparams, config_path)