Exemple #1
0
def get_optimized_mols(model_dir, ckpt=80000):
    """Restore a trained DQN and roll it out from every starting molecule.

    Hyperparameters are read from ``config.json`` inside ``model_dir``, the
    checkpoint ``ckpt-<ckpt>`` from the same directory is restored, and one
    episode of ``hparams.max_steps_per_episode`` steps is run (epsilon fixed
    at 0.0) for each entry of the module-level ``all_mols``.

    Args:
      model_dir: String. Directory holding ``config.json`` and the checkpoints.
      ckpt: Integer step number of the checkpoint to load.

    Returns:
      List with one entry per molecule in ``all_mols``; each entry is the
      value of ``environment.get_path()`` after that molecule's episode.
    """
    config_path = os.path.join(model_dir, 'config.json')
    with gfile.Open(config_path, 'r') as config_file:
        overrides = json.load(config_file)
    hparams = deep_q_networks.get_hparams(**overrides)

    # Each observation row is a fingerprint plus one extra scalar (the number
    # of steps left), hence the "+ 1" on the feature dimension.
    observation_shape = (hparams.batch_size, hparams.fingerprint_length + 1)
    q_function = functools.partial(
        deep_q_networks.multi_layer_model, hparams=hparams)
    dqn = deep_q_networks.DeepQNetwork(
        input_shape=observation_shape,
        q_fn=q_function,
        optimizer=hparams.optimizer,
        grad_clipping=hparams.grad_clipping,
        num_bootstrap_heads=hparams.num_bootstrap_heads,
        gamma=hparams.gamma,
        epsilon=0.0)

    tf.reset_default_graph()
    trajectories = []
    with tf.Session() as sess:
        dqn.build()
        # NOTE(review): the public TF 1.x API spells this tf.train.Saver;
        # confirm that the `tf` imported by this file exposes a top-level Saver.
        model_saver = tf.Saver(max_to_keep=hparams.max_num_checkpoints)
        model_saver.restore(sess, os.path.join(model_dir, 'ckpt-%i' % ckpt))
        for mol in all_mols:
            logging.info('Eval: %s', mol)
            max_steps = hparams.max_steps_per_episode
            environment = molecules_mdp.Molecule(
                atom_types=set(hparams.atom_types),
                init_mol=mol,
                allow_removal=hparams.allow_removal,
                allow_no_modification=hparams.allow_no_modification,
                allow_bonds_between_rings=hparams.allow_bonds_between_rings,
                allowed_ring_sizes=set(hparams.allowed_ring_sizes),
                max_steps=max_steps,
                record_path=True)
            environment.initialize()

            # One head is drawn per molecule and reused for its whole episode;
            # without bootstrap heads, head 0 is used.
            num_heads = hparams.num_bootstrap_heads
            head = np.random.randint(num_heads) if num_heads else 0

            for _ in range(max_steps):
                remaining = max_steps - environment.num_steps_taken
                candidates = list(environment.get_valid_actions())
                rows = [
                    np.append(
                        deep_q_networks.get_fingerprint(candidate, hparams),
                        remaining)
                    for candidate in candidates
                ]
                choice = dqn.get_action(
                    np.vstack(rows), head=head, update_epsilon=0.0)
                environment.step(candidates[choice])
            trajectories.append(environment.get_path())
    return trajectories
 def test_get_fingerprint(self):
     """Checks the length-64 fingerprint computed for 'c1ccccc1'."""
     length = 64
     hparams = deep_q_networks.get_hparams(fingerprint_length=length)
     actual = deep_q_networks.get_fingerprint('c1ccccc1', hparams).tolist()
     # Only these positions are expected to be set; every other entry is 0.
     set_positions = {0, 5, 16, 17}
     expected = [
         1 if position in set_positions else 0 for position in range(length)
     ]
     self.assertListEqual(actual, expected)
def _step(environment, dqn, memory, episode, hparams, exploration, head):
  """Runs a single step within an episode and stores it in the replay buffer.

  Every observation is the fingerprint of a candidate action with the
  environment's step counter appended. The transition written to `memory`
  uses the counter value read *before* the step for `obs_t` (the same value
  the network saw when it picked the action) and the value read *after* the
  step for `obs_tp1`.

  Args:
    environment: molecules.Molecule; the environment to run on.
    dqn: DeepQNetwork used for estimating rewards.
    memory: ReplayBuffer used to store observations and rewards.
    episode: Integer episode number.
    hparams: HParams.
    exploration: Schedule used for exploration in the environment.
    head: Integer index of the DeepQNetwork head to use.

  Returns:
    molecules.Result object containing the result of the step.
  """
  # Compute the encoding for each valid action from the current state.
  valid_actions = list(environment.get_valid_actions())
  # Read the counter once, before acting, so that the stored obs_t below
  # carries the same step feature as the rows used for action selection.
  steps_before = environment.num_steps_taken
  observations = np.vstack([
      np.append(deep_q_networks.get_fingerprint(act, hparams), steps_before)
      for act in valid_actions
  ])
  # Select the next action to take.
  action = valid_actions[dqn.get_action(
      observations, head=head, update_epsilon=exploration.value(episode))]
  # Encode the chosen action now: reading environment.num_steps_taken after
  # environment.step() would pick up the post-step counter and store an
  # observation the network never actually evaluated.
  action_t_fingerprint = np.append(
      deep_q_networks.get_fingerprint(action, hparams), steps_before)
  result = environment.step(action)
  # Compute the encoding for each valid action from the new state.
  steps_after = environment.num_steps_taken
  action_fingerprints = np.vstack([
      np.append(deep_q_networks.get_fingerprint(act, hparams), steps_after)
      for act in environment.get_valid_actions()
  ])
  # we store the fingerprint of the action in obs_t so action
  # does not matter here.
  memory.add(
      obs_t=action_t_fingerprint,
      action=0,
      reward=result.reward,
      obs_tp1=action_fingerprints,
      done=float(result.terminated))
  return result