def create_agent(sess, environment, summary_writer=None):
  """Creates an agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    environment: An Atari 2600 Gym environment.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
  if not FLAGS.debug_mode:
    summary_writer = None
  # Note: num_actions is fixed at 5 here; the environment argument is unused.
  if FLAGS.agent_name == 'dqn':
    return dqn_agent.DQNAgent(sess, num_actions=5,
                              summary_writer=summary_writer)
  elif FLAGS.agent_name == 'rainbow':
    return rainbow_agent.RainbowAgent(sess, num_actions=5,
                                      summary_writer=summary_writer)
  elif FLAGS.agent_name == 'implicit_quantile':
    return implicit_quantile_agent.ImplicitQuantileAgent(
        sess, num_actions=5, summary_writer=summary_writer)
  else:
    raise ValueError('Unknown agent: {}'.format(FLAGS.agent_name))
def create_agent(sess, environment, agent_name=None, summary_writer=None,
                 debug_mode=False):
  """Creates an agent.

  Args:
    sess: A `tf.compat.v1.Session` object for running associated ops.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    debug_mode: bool, whether to output Tensorboard summaries. If set to true,
      the agent will output in-episode statistics to Tensorboard. Disabled by
      default as this results in slower training.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
  assert agent_name is not None
  if not debug_mode:
    summary_writer = None
  if agent_name == 'dqn':
    return dqn_agent.DQNAgent(sess, num_actions=environment.action_space.n,
                              summary_writer=summary_writer)
  elif agent_name == 'rainbow':
    return rainbow_agent.RainbowAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'implicit_quantile':
    return implicit_quantile_agent.ImplicitQuantileAgent(
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'jax_dqn':
    return jax_dqn_agent.JaxDQNAgent(num_actions=environment.action_space.n,
                                     summary_writer=summary_writer)
  elif agent_name == 'jax_quantile':
    return jax_quantile_agent.JaxQuantileAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'jax_rainbow':
    return jax_rainbow_agent.JaxRainbowAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  elif agent_name == 'jax_implicit_quantile':
    return jax_implicit_quantile_agent.JaxImplicitQuantileAgent(
        num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  else:
    raise ValueError('Unknown agent: {}'.format(agent_name))
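# Usage sketch (not from the source): one way this factory might be handed to
# Dopamine's run_experiment.Runner, which supplies sess, environment and
# summary_writer when it calls the factory. The base_dir path and the chosen
# agent_name are placeholders; in practice the environment and schedule are
# typically configured via gin bindings.
import functools

from dopamine.discrete_domains import run_experiment

create_rainbow_fn = functools.partial(
    create_agent, agent_name='rainbow', debug_mode=True)
runner = run_experiment.Runner(
    base_dir='/tmp/rainbow_run',  # placeholder directory
    create_agent_fn=create_rainbow_fn)
runner.run_experiment()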
def testCreateAgentWithDefaults(self):
  # Verifies that we can create and train an agent with the default values.
  with tf.Session() as sess:
    agent = rainbow_agent.RainbowAgent(sess, num_actions=4)
    sess.run(tf.global_variables_initializer())
    observation = np.ones([84, 84, 1])
    agent.begin_episode(observation)
    agent.step(reward=1, observation=observation)
    agent.end_episode(reward=1)
def testStoreTransitionWithPrioritizedSampling(self):
  with tf.Session() as sess:
    agent = rainbow_agent.RainbowAgent(
        sess, num_actions=4, replay_scheme='prioritized')
    dummy_frame = np.zeros((84, 84))
    # Adding transitions with default, 10., and default priorities.
    agent._store_transition(dummy_frame, 0, 0, False)
    agent._store_transition(dummy_frame, 0, 0, False, 10.)
    agent._store_transition(dummy_frame, 0, 0, False)
    returned_priorities = agent._replay.memory.get_priority(
        np.arange(self.stack_size - 1, self.stack_size + 2, dtype=np.int32))
    expected_priorities = [1., 10., 10.]
    self.assertAllEqual(returned_priorities, expected_priorities)
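# Sketch of the default-priority rule the test above exercises (an assumption
# about RainbowAgent._store_transition, not code from the source): with the
# 'prioritized' replay scheme, a transition stored without an explicit
# priority inherits the buffer's current maximum recorded priority, which is
# why the third add lands at 10. rather than 1.
def _default_priority(replay_scheme, max_recorded_priority):
  if replay_scheme == 'uniform':
    return 1.
  return max_recorded_priority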
def __init__(self, **kwargs):
  super(RainbowWrapper, self).__init__(**kwargs)
  args, num_inputs, num_outputs, factor = self.get_args(kwargs)
  self.minmax = default_value_arg(kwargs, 'minmax', None)
  if args.true_environment and args.state_forms[0] == 'raw':
    network = atari_lib.rainbow_network
    print(network)
    observation_shape = (84, 84)
  else:
    network = create_rainbow_network(self.minmax)
    print("num inputs", num_inputs)
    observation_shape = (num_inputs,)
  self.sess = tf.Session(
      '', config=tf.ConfigProto(allow_soft_placement=True))
  # local_device_protos = device_lib.list_local_devices()
  # print([x.name for x in local_device_protos])
  # print(atari_lib.NATURE_DQN_OBSERVATION_SHAPE,
  #       atari_lib.NATURE_DQN_DTYPE,
  #       atari_lib.NATURE_DQN_STACK_SIZE)
  # print(num_inputs)
  self.dope_rainbow = rainbow_agent.RainbowAgent(
      self.sess,
      num_outputs,
      observation_shape=observation_shape,
      observation_dtype=tf.float32,
      stack_size=args.num_stack,
      network=network,
      num_atoms=51,
      vmax=args.value_bounds[1],
      gamma=args.gamma,
      update_horizon=1,
      min_replay_history=20000,
      update_period=4,
      target_update_period=8000,
      epsilon_fn=dqn_agent.linearly_decaying_epsilon,
      epsilon_train=0.01,
      epsilon_eval=0.001,
      epsilon_decay_period=250000,
      replay_scheme='prioritized',
      tf_device='/gpu:' + str(args.gpu),
      use_staging=True,
      optimizer=tf.train.AdamOptimizer(learning_rate=.00025,
                                       epsilon=0.0003125),
      summary_writer=None,
      summary_writing_frequency=500)
  self.dope_rainbow.eval_mode = False
  self.sess.run(tf.global_variables_initializer())
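# Hypothetical sketch of the create_rainbow_network factory referenced above,
# which the source does not show. It assumes the older function-style Dopamine
# network API, network(num_actions, num_atoms, support, network_type, state),
# matching atari_lib.rainbow_network. The layer sizes and the min-max scaling
# are illustrative assumptions only.
import tensorflow as tf


def create_rainbow_network(minmax):
  def network(num_actions, num_atoms, support, network_type, state):
    net = tf.layers.flatten(tf.cast(state, tf.float32))
    if minmax is not None:
      low, high = minmax
      net = (net - low) / (high - low)  # scale observations to [0, 1]
    net = tf.layers.dense(net, 512, activation=tf.nn.relu)
    net = tf.layers.dense(net, num_actions * num_atoms, activation=None)
    logits = tf.reshape(net, [-1, num_actions, num_atoms])
    probabilities = tf.nn.softmax(logits)
    q_values = tf.reduce_sum(support * probabilities, axis=2)
    return network_type(q_values, logits, probabilities)
  return network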
def create_agent(sess, agent_name, num_actions,
                 observation_shape=atari_lib.NATURE_DQN_OBSERVATION_SHAPE,
                 observation_dtype=atari_lib.NATURE_DQN_DTYPE,
                 stack_size=atari_lib.NATURE_DQN_STACK_SIZE,
                 summary_writer=None):
  """Creates an agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    agent_name: str, name of the agent to create.
    num_actions: int, number of actions the agent can take at any state.
    observation_shape: tuple of ints describing the observation shape.
    observation_dtype: tf.DType, specifies the type of the observations. Note
      that if your inputs are continuous, you should set this to tf.float32.
    stack_size: int, number of frames to use in state stack.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
  if agent_name == 'dqn':
    return dqn_agent.DQNAgent(
        sess, num_actions,
        observation_shape=observation_shape,
        observation_dtype=observation_dtype,
        stack_size=stack_size,
        summary_writer=summary_writer)
  elif agent_name == 'rainbow':
    return rainbow_agent.RainbowAgent(
        sess, num_actions,
        observation_shape=observation_shape,
        observation_dtype=observation_dtype,
        stack_size=stack_size,
        summary_writer=summary_writer)
  elif agent_name == 'implicit_quantile':
    return implicit_quantile_agent.ImplicitQuantileAgent(
        sess, num_actions, summary_writer=summary_writer)
  else:
    raise ValueError('Unknown agent: {}'.format(agent_name))
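# Usage sketch (not from the source): creating an agent for a standard Atari
# environment, relying on the Nature DQN defaults for observation shape,
# dtype and stack size. Assumes Dopamine's atari_lib.create_atari_environment
# helper is available; the game name is a placeholder.
from dopamine.discrete_domains import atari_lib
import tensorflow as tf

environment = atari_lib.create_atari_environment(game_name='Pong')
with tf.Session() as sess:
  agent = create_agent(sess, 'rainbow',
                       num_actions=environment.action_space.n)
  sess.run(tf.global_variables_initializer())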
def create_agent(sess, environment):
  """Creates an agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    environment: An Atari 2600 Gym environment.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
  if FLAGS.agent_name == 'dqn':
    return dqn_agent.DQNAgent(sess, num_actions=environment.action_space.n)
  elif FLAGS.agent_name == 'rainbow':
    return rainbow_agent.RainbowAgent(
        sess, num_actions=environment.action_space.n)
  elif FLAGS.agent_name == 'implicit_quantile':
    return implicit_quantile_agent.ImplicitQuantileAgent(
        sess, num_actions=environment.action_space.n)
  else:
    raise ValueError('Unknown agent: {}'.format(FLAGS.agent_name))
def create_agent_fn(sess, env, summary_writer):
  return rainbow_agent.RainbowAgent(
      sess=sess,
      num_actions=env.action_space.n,
      summary_writer=summary_writer)
def _create_test_agent(self, sess):
  stack_size = self.stack_size

  # This dummy network allows us to deterministically anticipate that
  # action 0 will be selected by an argmax.
  # In Rainbow we are dealing with a distribution over Q-values,
  # which are represented as num_atoms bins, ranging from -vmax to vmax.
  # The output layer will have num_actions * num_atoms elements,
  # so each group of num_atoms weights represents the logits for a
  # particular action. By setting 1s everywhere, except for the first
  # num_atoms (representing the logits for the first action), which are
  # set to np.arange(num_atoms), we are ensuring that the first action
  # places higher weight on higher Q-values; this results in the first
  # action being chosen.
  class MockRainbowNetwork(tf.keras.Model):
    """Custom tf.keras.Model used in tests."""

    def __init__(self, num_actions, num_atoms, support, **kwargs):
      super(MockRainbowNetwork, self).__init__(**kwargs)
      self.num_actions = num_actions
      self.num_atoms = num_atoms
      self.support = support
      first_row = np.tile(np.ones(self.num_atoms), self.num_actions - 1)
      first_row = np.concatenate((np.arange(self.num_atoms), first_row))
      bottom_rows = np.tile(
          np.ones(self.num_actions * self.num_atoms), (stack_size - 1, 1))
      weights_initializer = np.concatenate(([first_row], bottom_rows))
      self.layer = tf.keras.layers.Dense(
          self.num_actions * self.num_atoms,
          kernel_initializer=tf.constant_initializer(weights_initializer),
          bias_initializer=tf.ones_initializer())

    def call(self, state):
      inputs = tf.constant(
          np.zeros((state.shape[0], stack_size)), dtype=tf.float32)
      net = self.layer(inputs)
      logits = tf.reshape(net, [-1, self.num_actions, self.num_atoms])
      probabilities = tf.keras.activations.softmax(logits)
      qs = tf.reduce_sum(self.support * probabilities, axis=2)
      return atari_lib.RainbowNetworkType(qs, logits, probabilities)

  agent = rainbow_agent.RainbowAgent(
      sess=sess,
      network=MockRainbowNetwork,
      num_actions=self._num_actions,
      num_atoms=self._num_atoms,
      vmax=self._vmax,
      min_replay_history=self._min_replay_history,
      epsilon_fn=lambda w, x, y, z: 0.0,  # No exploration.
      epsilon_eval=0.0,
      epsilon_decay_period=self._epsilon_decay_period)
  # This ensures non-random action choices (since epsilon_eval = 0.0) and
  # skips the train_step.
  agent.eval_mode = True
  sess.run(tf.global_variables_initializer())
  return agent
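# Hypothetical companion test (not from the source): with the mock network
# above, the expected Q-value for action 0 dominates, so the agent's argmax
# policy should always pick action 0. The observation shape assumes the
# agent's default Nature DQN settings.
def testBeginEpisodeSelectsFirstAction(self):
  with tf.Session() as sess:
    agent = self._create_test_agent(sess)
    observation = np.ones((84, 84))
    self.assertEqual(agent.begin_episode(observation), 0)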