예제 #1
0
 def test_get_fingerprint(self):
     """Pins get_fingerprint's 64-entry output for benzene ('c1ccccc1')."""
     num_bits = 64
     hparams = deep_q_networks.get_hparams(fingerprint_length=num_bits)
     actual = deep_q_networks.get_fingerprint('c1ccccc1', hparams).tolist()
     # Every position is expected to be 0 except these four set bits.
     expected = [0] * num_bits
     for set_bit in (0, 5, 16, 17):
         expected[set_bit] = 1
     self.assertListEqual(actual, expected)
예제 #2
0
def _encode_actions(actions, steps_left, hparams):
  """Encodes candidate actions as rows of fingerprint + remaining step count.

  Args:
    actions: Iterable of actions to encode; each is passed to
      deep_q_networks.get_fingerprint.
    steps_left: Integer number of steps remaining in the episode; appended as
      the last entry of every row.
    hparams: HParams forwarded to deep_q_networks.get_fingerprint.

  Returns:
    np.ndarray built by np.vstack, with one row per action in iteration order.
  """
  return np.vstack([
      np.append(deep_q_networks.get_fingerprint(act, hparams), steps_left)
      for act in actions
  ])


def _step(environment, dqn, memory, episode, hparams, exploration, head):
  """Runs a single step within an episode.

  Encodes every valid action from the current state, asks the DQN to choose
  one (with the exploration schedule's value for this episode as epsilon),
  applies it to the environment, and records the transition in the replay
  buffer.

  Args:
    environment: molecules.Molecule; the environment to run on.
    dqn: DeepQNetwork used for estimating rewards.
    memory: ReplayBuffer used to store observations and rewards.
    episode: Integer episode number.
    hparams: HParams.
    exploration: Schedule used for exploration in the environment.
    head: Integer index of the DeepQNetwork head to use.

  Returns:
    molecules.Result object containing the result of the step.
  """
  # Compute the encoding for each valid action from the current state. The
  # step budget is part of the observation, measured before this step.
  steps_left = hparams.max_steps_per_episode - environment.num_steps_taken
  valid_actions = list(environment.get_valid_actions())
  observations = _encode_actions(valid_actions, steps_left, hparams)
  action_index = dqn.get_action(
      observations, head=head, update_epsilon=exploration.value(episode))
  action = valid_actions[action_index]
  # The chosen action's encoding is already a row of `observations`; reuse it
  # instead of recomputing the fingerprint. Copy it so the stored transition
  # does not keep the whole `observations` matrix alive through a view.
  action_t_fingerprint = observations[action_index].copy()
  result = environment.step(action)
  # Encode the actions available from the new state, with the updated budget.
  steps_left = hparams.max_steps_per_episode - environment.num_steps_taken
  action_fingerprints = _encode_actions(
      environment.get_valid_actions(), steps_left, hparams)
  # We store the fingerprint of the action in obs_t, so the action field
  # carries no information and is fixed at 0.
  memory.add(
      obs_t=action_t_fingerprint,
      action=0,
      reward=result.reward,
      obs_tp1=action_fingerprints,
      done=float(result.terminated))
  return result