Code example #1
File: train.py Project: zzhaozeng/google-research
    def create_agent_fn(sess, environment, summary_writer):
        """Creates the appropriate agent."""
        if agent_type == 'dqn':
            return dqn_agent.DQNAgent(sess=sess,
                                      num_actions=environment.action_space.n,
                                      summary_writer=summary_writer)

        if agent_type == 'al_dqn':
            return al_dqn.ALDQNAgent(sess=sess,
                                     num_actions=environment.action_space.n,
                                     alpha=FLAGS.shaping_scale,
                                     persistent=FLAGS.persistent,
                                     summary_writer=summary_writer)

        if agent_type == 'm_dqn':
            return m_dqn.MunchausenDQNAgent(
                sess=sess,
                num_actions=environment.action_space.n,
                tau=FLAGS.tau,
                alpha=FLAGS.alpha,
                clip_value_min=FLAGS.clip_value_min,
                interact=FLAGS.interact,
                summary_writer=summary_writer)

        if agent_type == 'm_iqn':
            return m_iqn.MunchausenIQNAgent(
                sess=sess,
                num_actions=environment.action_space.n,
                tau=FLAGS.tau,
                alpha=FLAGS.alpha,
                interact=FLAGS.interact,
                clip_value_min=FLAGS.clip_value_min,
                summary_writer=summary_writer)

        raise ValueError('Wrong agent %s' % agent_type)
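A factory like create_agent_fn is not called directly: Dopamine's Runner invokes it with the session, environment and summary writer it manages. The sketch below shows one typical way to wire such a factory up. It is an illustrative assumption rather than part of the example above; the module paths follow the TF version of Dopamine, and '/tmp/dqn_experiment' and 'dqn.gin' are hypothetical placeholders.

# A minimal sketch (an assumption, not taken from the example above) of how a
# create_agent_fn factory is handed to Dopamine's Runner, which supplies the
# session, environment and summary writer itself.
from dopamine.agents.dqn import dqn_agent
from dopamine.discrete_domains import run_experiment


def create_agent_fn(sess, environment, summary_writer=None):
  # The signature the Runner expects: (sess, environment, summary_writer).
  return dqn_agent.DQNAgent(
      sess=sess,
      num_actions=environment.action_space.n,
      summary_writer=summary_writer)


# A gin config normally selects the game and hyperparameters first;
# 'dqn.gin' and the base_dir path are hypothetical.
run_experiment.load_gin_configs(['dqn.gin'], [])
runner = run_experiment.Runner('/tmp/dqn_experiment', create_agent_fn)
runner.run_experiment()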
Code example #2
def create_agent_fn(sess, environment, summary_writer):
  """Creates the appropriate agent."""
  if agent_type == 'dqn':
    return dqn_agent.DQNAgent(
        sess=sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_type == 'iqn':
    return implicit_quantile_agent.ImplicitQuantileAgent(
        sess=sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_type == 'al_dqn':
    return al_dqn.ALDQNAgent(
        sess=sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_type == 'al_iqn':
    return al_iqn.ALImplicitQuantileAgent(
        sess=sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_type == 'sail_dqn':
    return sail_dqn.SAILDQNAgent(
        sess=sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_type == 'sail_iqn':
    return sail_iqn.SAILImplicitQuantileAgent(
        sess=sess,
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  else:
    raise ValueError('Wrong agent %s' % agent_type)
Code example #3
def create_agent(sess,
                 environment,
                 base_dir,
                 agent_name=None,
                 summary_writer=None,
                 debug_mode=False):
    """Creates an agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    debug_mode: bool, whether to output Tensorboard summaries. If set to true,
      the agent will output in-episode statistics to Tensorboard. Disabled by
      default as this results in slower training.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
    assert agent_name is not None
    if not debug_mode:
        summary_writer = None
    if agent_name == 'dqn':
        return dqn_agent.DQNAgent(sess,
                                  num_actions=environment.action_space.n,
                                  base_dir=base_dir,
                                  summary_writer=summary_writer)
    else:
        raise ValueError('Unknown agent: {}'.format(agent_name))
Code example #4
  def testCreateAgentWithDefaults(self):
    # Verifies that we can create and train an agent with the default values.
    with tf.Session() as sess:
      agent = dqn_agent.DQNAgent(sess, num_actions=4)
      sess.run(tf.global_variables_initializer())
      observation = np.ones([84, 84, 1])
      agent.begin_episode(observation)
      agent.step(reward=1, observation=observation)
      agent.end_episode(reward=1)
Code example #5
File: run_experiment.py Project: alhamzah/dopamine
def create_agent(sess, environment, agent_name=None, summary_writer=None,
                 debug_mode=False):
  """Creates an agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    debug_mode: bool, whether to output Tensorboard summaries. If set to true,
      the agent will output in-episode statistics to Tensorboard. Disabled by
      default as this results in slower training.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
  assert agent_name is not None
  if not debug_mode:
    summary_writer = None
  if agent_name == 'dqn':
    return dqn_agent.DQNAgent(sess, num_actions=environment.action_space.n,
                              summary_writer=summary_writer)
  elif agent_name == 'rainbow':
    return rainbow_agent.RainbowAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'implicit_quantile':
    return implicit_quantile_agent.ImplicitQuantileAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'ddqn':
    return ddqn_agent.DDQNAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'multi_head_ddqn':
    return multi_head_ddqn_agent.MultiHeadDDQNAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'multi_head_ucb':
    return multi_head_ucb_agent.MultiHeadUCBAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'multi_head_thompson':
    return multi_head_thompson_agent.MultiHeadThompsonAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'multi_head_contextual_ucb':
    return multi_head_contextual_ucb_agent.MultiHeadContextualUCBAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  else:
    raise ValueError('Unknown agent: {}'.format(agent_name))
Code example #6
    def testBundling(self):
        """Tests that local values are poperly updated when reading a checkpoint."""
        with tf.Session() as sess:
            agent = dqn_agent.DQNAgent(sess, 3, observation_shape=(2, 2))
            sess.run(tf.global_variables_initializer())
            agent.state = 'state_val'
            bundle = agent.bundle_and_checkpoint(self.get_temp_dir(),
                                                 iteration_number=10)
            self.assertIn('state', bundle)
            self.assertEqual(bundle['state'], 'state_val')
            bundle['state'] = 'new_state_val'

            with test_utils.mock_thread('other-thread'):
                agent.unbundle(self.get_temp_dir(),
                               iteration_number=10,
                               bundle_dictionary=bundle)
                self.assertEqual(agent.state, 'new_state_val')
            self.assertEqual(agent.state, 'state_val')
Code example #7
def create_agent(sess, agent_name, num_actions,
                 observation_shape=atari_lib.NATURE_DQN_OBSERVATION_SHAPE,
                 observation_dtype=atari_lib.NATURE_DQN_DTYPE,
                 stack_size=atari_lib.NATURE_DQN_STACK_SIZE,
                 summary_writer=None):
  """Creates an agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    agent_name: str, name of the agent to create.
    num_actions: int, number of actions the agent can take at any state.
    observation_shape: tuple of ints describing the observation shape.
    observation_dtype: tf.DType, specifies the type of the observations. Note
      that if your inputs are continuous, you should set this to tf.float32.
    stack_size: int, number of frames to use in state stack.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list or one of the
      GAIRL submodules is not in supported list when the chosen agent is GAIRL.
  """
  if agent_name == 'dqn':
    return dqn_agent.DQNAgent(
      sess, num_actions, observation_shape=observation_shape,
      observation_dtype=observation_dtype, stack_size=stack_size,
      summary_writer=summary_writer
    )
  elif agent_name == 'rainbow':
    return rainbow_agent.RainbowAgent(
      sess, num_actions, observation_shape=observation_shape,
      observation_dtype=observation_dtype, stack_size=stack_size,
      summary_writer=summary_writer
    )
  elif agent_name == 'implicit_quantile':
    return implicit_quantile_agent.ImplicitQuantileAgent(
      sess, num_actions, summary_writer=summary_writer
    )
  else:
    raise ValueError('Unknown agent: {}'.format(agent_name))
Code example #8
  def _create_test_agent(self, sess, allow_partial_reload=False):
    stack_size = self.stack_size

    # This dummy network allows us to deterministically anticipate that
    # action 0 will be selected by an argmax.
    class MockDQNNetwork(tf.keras.Model):
      """The Keras network used in tests."""

      def __init__(self, num_actions, **kwargs):
        # This weights_initializer gives action 0 a higher weight, ensuring
        # that it gets picked by the argmax.
        super(MockDQNNetwork, self).__init__(**kwargs)
        weights_initializer = np.tile(
            np.arange(num_actions, 0, -1), (stack_size, 1))
        self.layer = tf.keras.layers.Dense(
            num_actions,
            kernel_initializer=tf.constant_initializer(weights_initializer),
            bias_initializer=tf.ones_initializer())

      def call(self, state):
        inputs = tf.constant(
            np.zeros((state.shape[0], stack_size)), dtype=tf.float32)
        return atari_lib.DQNNetworkType(self.layer((inputs)))

    agent = dqn_agent.DQNAgent(
        sess=sess,
        network=MockDQNNetwork,
        observation_shape=self.observation_shape,
        observation_dtype=self.observation_dtype,
        stack_size=self.stack_size,
        num_actions=self.num_actions,
        min_replay_history=self.min_replay_history,
        epsilon_fn=lambda w, x, y, z: 0.0,  # No exploration.
        update_period=self.update_period,
        target_update_period=self.target_update_period,
        epsilon_eval=0.0,  # No exploration during evaluation.
        allow_partial_reload=allow_partial_reload)
    # This ensures non-random action choices (since epsilon_eval = 0.0) and
    # skips the train_step.
    agent.eval_mode = True
    sess.run(tf.global_variables_initializer())
    return agent
Code example #9
File: run.py Project: yumbohorquez/bsuite
def create_agent(sess: tf.Session,
                 environment: gym.Env,
                 summary_writer=None):
    """Factory method for agent initialization in Dopamine."""
    del summary_writer
    return dqn_agent.DQNAgent(
        sess=sess,
        num_actions=environment.action_space.n,
        observation_shape=OBSERVATION_SHAPE,
        observation_dtype=tf.float32,
        stack_size=1,
        network=Network,
        gamma=FLAGS.agent_discount,
        update_horizon=1,
        min_replay_history=FLAGS.min_replay_size,
        update_period=FLAGS.sgd_period,
        target_update_period=FLAGS.target_update_period,
        epsilon_decay_period=FLAGS.epsilon_decay_period,
        epsilon_train=FLAGS.epsilon,
        optimizer=tf.train.AdamOptimizer(FLAGS.learning_rate),
    )
Code example #10
    def __init__(self, **kwargs):
        super(DQNWrapper, self).__init__(**kwargs)
        args, num_inputs, num_outputs, factor = self.get_args(kwargs)
        self.minmax = default_value_arg(kwargs, 'minmax', None)
        network = create_dqn_network(self.minmax)
        observation_shape = (num_inputs, )
        if args.true_environment:
            network = atari_lib.nature_dqn_network
            observation_shape = (84, 84)

        self.sess = tf.Session(
            '', config=tf.ConfigProto(allow_soft_placement=True))
        self.dope_dqn = dqn_agent.DQNAgent(
            self.sess,
            num_outputs,
            observation_shape=observation_shape,
            observation_dtype=tf.float32,
            stack_size=args.num_stack,
            network=network,
            gamma=args.gamma,
            update_horizon=3,
            min_replay_history=20000,
            update_period=4,
            target_update_period=8000,
            epsilon_fn=dqn_agent.linearly_decaying_epsilon,
            epsilon_train=args.greedy_epsilon,
            epsilon_eval=0.001,
            epsilon_decay_period=250000,
            tf_device='/gpu:3',
            use_staging=True,
            max_tf_checkpoints_to_keep=4,
            optimizer=tf.train.RMSPropOptimizer(learning_rate=args.lr,
                                                decay=0.95,
                                                momentum=0.0,
                                                epsilon=args.eps,
                                                centered=True),
            summary_writer=None,
            summary_writing_frequency=500)
        init = tf.global_variables_initializer()
        self.sess.run(init)
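Once constructed, the wrapped agent is driven with Dopamine's episode API (begin_episode, step, end_episode), as the tests in examples #4 and #12 show. Below is a minimal sketch, assuming a Gym-style environment env whose observations match observation_shape; the loop itself is not part of the wrapper above.

        # Minimal sketch: one episode driven through the agent built above.
        # `env` is a hypothetical Gym-style environment.
        agent = self.dope_dqn
        observation = env.reset()
        action = agent.begin_episode(observation)
        done = False
        while not done:
            observation, reward, done, _ = env.step(action)
            if done:
                agent.end_episode(reward)
            else:
                action = agent.step(reward, observation)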
Code example #11
def create_agent(sess, environment):
    """Creates a DQN agent.

    Args:
      sess: A `tf.Session` object for running associated ops.
      environment: An Atari 2600 Gym environment.

    Returns:
      agent: An RL agent.

    Raises:
      ValueError: If `agent_name` is not in supported list.
    """
    if FLAGS.agent_name == 'dqn':
        return dqn_agent.DQNAgent(sess, num_actions=environment.action_space.n)
    elif FLAGS.agent_name == 'rainbow':
        return rainbow_agent.RainbowAgent(
            sess, num_actions=environment.action_space.n)
    elif FLAGS.agent_name == 'implicit_quantile':
        return implicit_quantile_agent.ImplicitQuantileAgent(
            sess, num_actions=environment.action_space.n)
    else:
        raise ValueError('Unknown agent: {}'.format(FLAGS.agent_name))
Code example #12
    def testLocalValues(self):
        """Tests that episode related variables are thread specific."""
        with tf.Session() as sess:
            observation_shape = (2, 2)
            agent = dqn_agent.DQNAgent(sess,
                                       3,
                                       observation_shape=observation_shape)
            sess.run(tf.global_variables_initializer())

            with test_utils.mock_thread('baseline-thread'):
                agent.begin_episode(observation=np.zeros(observation_shape),
                                    training=False)
                local_values_1 = (agent._observation, agent._last_observation,
                                  agent.state)

            with test_utils.mock_thread('different-thread'):
                agent.begin_episode(observation=np.zeros(observation_shape),
                                    training=False)
                agent.step(reward=10,
                           observation=np.ones(observation_shape),
                           training=False)
                local_values_3 = (agent._observation, agent._last_observation,
                                  agent.state)

            with test_utils.mock_thread('identical-thread'):
                agent.begin_episode(observation=np.zeros(observation_shape),
                                    training=False)
                local_values_2 = (agent._observation, agent._last_observation,
                                  agent.state)

            # Asserts that values in 'identical-thread' are the same as the baseline.
            for val_1, val_2 in zip(local_values_1, local_values_2):
                self.assertTrue(np.all(val_1 == val_2))

            # Asserts that values in 'different-thread' are different from the baseline.
            for val_1, val_3 in zip(local_values_1, local_values_3):
                self.assertTrue(np.any(val_1 != val_3))