def dequeue(ctx): """Inserts into and samples from the replay buffer. Args: ctx: tf.distribute.InputContext. Returns: A flattened `SampledUnrolls` structures where per-timestep tensors have front dimensions [unroll_length, batch_size_per_replica]. """ per_replica_batch_size = ctx.get_per_replica_batch_size(batch_size) insertion_batch_size = get_replay_insertion_batch_size( per_replica=True) print_every = tf.cast( insertion_batch_size * (1 + FLAGS.replay_buffer_min_size // 50 // insertion_batch_size), tf.int64) log_summary_every = tf.cast(insertion_batch_size * 500, tf.int64) while tf.constant(True): # Each tensor in 'unrolls' has shape [insertion_batch_size, unroll_length, # <field-specific dimensions>]. unrolls = unroll_queue.dequeue_many(insertion_batch_size) # The replay buffer is not threadsafe (and making it thread-safe might # slow it down), which is why we insert and sample in a single thread, and # use TF Queues for passing data between threads. replay_buffer.insert(unrolls, unrolls.priority) if tf.equal(replay_buffer.num_inserted % log_summary_every, 0): # Unfortunately, there is no tf.summary(log_every_n_sec). tf.summary.histogram('initial_priorities', unrolls.priority) if replay_buffer.num_inserted >= FLAGS.replay_buffer_min_size: break if tf.equal(replay_buffer.num_inserted % print_every, 0): tf.print( 'Waiting for the replay buffer to fill up. ' 'It currently has', replay_buffer.num_inserted, 'elements, waiting for at least', FLAGS.replay_buffer_min_size, 'elements') sampled_indices, weights, sampled_unrolls = replay_buffer.sample( per_replica_batch_size, priority_exponent) sampled_unrolls = sampled_unrolls._replace( prev_actions=utils.make_time_major(sampled_unrolls.prev_actions), env_outputs=utils.make_time_major(sampled_unrolls.env_outputs), agent_outputs=utils.make_time_major(sampled_unrolls.agent_outputs)) sampled_unrolls = sampled_unrolls._replace( env_outputs=encode(sampled_unrolls.env_outputs)) # tf.data.Dataset treats list leafs as tensors, so we need to flatten and # repack. return tf.nest.flatten( SampledUnrolls(sampled_unrolls, sampled_indices, weights))
def dequeue(ctx):
  # Create batch (time major).
  actor_outputs = tf.nest.map_structure(lambda *args: tf.stack(args), *[
      unroll_queue.dequeue()
      for i in range(ctx.get_per_replica_batch_size(FLAGS.batch_size))
  ])
  actor_outputs = actor_outputs._replace(
      prev_actions=utils.make_time_major(actor_outputs.prev_actions),
      env_outputs=utils.make_time_major(actor_outputs.env_outputs),
      agent_outputs=utils.make_time_major(actor_outputs.agent_outputs))
  actor_outputs = actor_outputs._replace(
      env_outputs=encode(actor_outputs.env_outputs))
  # tf.data.Dataset treats list leaves as tensors, so we need to flatten and
  # repack.
  return tf.nest.flatten(actor_outputs)

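# Shape walk-through for the `dequeue` variant above: each
# `unroll_queue.dequeue()` presumably yields one unroll whose per-timestep
# tensors are shaped [unroll_length, ...]; `tf.stack` prepends the batch
# dimension, giving [batch_size_per_replica, unroll_length, ...]; and
# `make_time_major` then swaps the first two dimensions into the time-major
# [unroll_length, batch_size_per_replica, ...] layout the training loop
# consumes.
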
def dequeue(ctx): """Inserts into and samples from the replay buffer. Args: ctx: tf.distribute.InputContext. Returns: A flattened `Unroll` structures where per-timestep tensors have front dimensions [unroll_length, batch_size_per_replica]. """ per_replica_batch_size = ctx.get_per_replica_batch_size(batch_size) while tf.constant(True): # Each tensor in 'unrolls' has shape [insertion_batch_size, unroll_length, # <field-specific dimensions>]. insert_batch_size = get_replay_insertion_batch_size( per_replica=True) unrolls = unroll_queue.dequeue_many(insert_batch_size) # The replay buffer is not threadsafe (and making it thread-safe might # slow it down), which is why we insert and sample in a single thread, and # use TF Queues for passing data between threads. replay_buffer.insert(unrolls, priorities=tf.ones(insert_batch_size)) if replay_buffer.num_inserted >= FLAGS.batch_size: break _, _, sampled_unrolls = replay_buffer.sample(per_replica_batch_size, priority_exp=0.) sampled_unrolls = sampled_unrolls._replace( prev_actions=utils.make_time_major(sampled_unrolls.prev_actions), env_outputs=utils.make_time_major(sampled_unrolls.env_outputs), agent_actions=utils.make_time_major(sampled_unrolls.agent_actions)) sampled_unrolls = sampled_unrolls._replace( env_outputs=encode(sampled_unrolls.env_outputs)) # tf.data.Dataset treats list leafs as tensors, so we need to flatten and # repack. return tf.nest.flatten(sampled_unrolls)
def test_nest(self):
  x = (tf.constant([[1, 2], [3, 4]]), tf.constant([[1], [2]]))
  a, b = utils.make_time_major(x)
  self.assertAllEqual(a, tf.constant([[1, 3], [2, 4]]))
  self.assertAllEqual(b, tf.constant([[1, 2]]))

def test_uint16(self):
  x = tf.constant([[1, 2], [3, 4]], tf.uint16)
  self.assertAllEqual(utils.make_time_major(x),
                      tf.constant([[1, 3], [2, 4]]))

def test_dynamic(self):
  x, = tf.py_function(
      lambda: np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), [],
      [tf.int32])
  self.assertAllEqual(
      utils.make_time_major(x),
      tf.constant([[[1, 2], [5, 6]], [[3, 4], [7, 8]]]))

def test_static(self):
  x = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
  self.assertAllEqual(
      utils.make_time_major(x),
      tf.constant([[[1, 2], [5, 6]], [[3, 4], [7, 8]]]))

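# The four tests above pin down the contract of `utils.make_time_major`:
# swap the first two dimensions of every tensor in a (possibly nested)
# structure, i.e. [batch, time, ...] -> [time, batch, ...], even when the
# rank is only known at runtime (test_dynamic). A minimal sketch that
# satisfies these tests; the production version may need extra care, e.g.
# for dtypes such as uint16 on some devices:
def _make_time_major_sketch(nest):

  def swap_leading_dims(t):
    # Build the permutation [1, 0, 2, 3, ...] dynamically so tensors with
    # statically unknown rank are handled too.
    perm = tf.concat([[1, 0], tf.range(2, tf.rank(t))], axis=0)
    return tf.transpose(t, perm)

  return tf.nest.map_structure(swap_leading_dims, nest)
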
def inference(env_ids, run_ids, env_outputs, raw_rewards):
  """Agent inference.

  This evaluates the agent policy on the provided environment data (reward,
  done, observation), and stores appropriate data to feed the main training
  loop.

  Args:
    env_ids: <int32>[inference_batch_size], the environment task IDs (in
      range [0, num_tasks)).
    run_ids: <int64>[inference_batch_size], the environment run IDs. An
      environment generates a random int64 run id at startup, so this can be
      used to detect environment jobs that restarted.
    env_outputs: Follows env_output_specs, but with inference_batch_size
      added as the first dimension. These are the actual environment outputs
      (reward, done, observation).
    raw_rewards: <float32>[inference_batch_size], representing the raw
      reward of each step.

  Returns:
    A tensor <int32>[inference_batch_size] with one action for each
    environment.
  """
  # Reset the environments that had their first run or crashed.
  previous_run_ids = env_run_ids.read(env_ids)
  env_run_ids.replace(env_ids, run_ids)
  reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
  envs_needing_reset = tf.gather(env_ids, reset_indices)
  if tf.not_equal(tf.shape(envs_needing_reset)[0], 0):
    tf.print('Environments needing reset:', envs_needing_reset)
  env_infos.reset(envs_needing_reset)
  store.reset(
      tf.gather(envs_needing_reset,
                tf.where(is_training_env(envs_needing_reset))[:, 0]))
  initial_agent_states = agent.initial_state(tf.shape(envs_needing_reset)[0])
  first_agent_states.replace(envs_needing_reset, initial_agent_states)
  agent_states.replace(envs_needing_reset, initial_agent_states)
  actions.reset(envs_needing_reset)

  tf.debugging.assert_non_positive(
      tf.cast(env_outputs.abandoned, tf.int32),
      'Abandoned done states are not supported in R2D2.')

  # Update steps and return.
  env_infos.add(env_ids, (0, env_outputs.reward, raw_rewards))
  done_ids = tf.gather(env_ids, tf.where(env_outputs.done)[:, 0])
  done_episodes_info = env_infos.read(done_ids)
  info_queue.enqueue_many(EpisodeInfo(*(done_episodes_info + (done_ids,))))
  env_infos.reset(done_ids)
  env_infos.add(env_ids, (FLAGS.num_action_repeats, 0., 0.))

  # Inference.
  prev_actions = actions.read(env_ids)
  input_ = encode((prev_actions, env_outputs))
  prev_agent_states = agent_states.read(env_ids)

  def make_inference_fn(inference_device):

    def device_specific_inference_fn():
      with tf.device(inference_device):

        @tf.function
        def agent_inference(*args):
          return agent(*decode(args))

        return agent_inference(input_, prev_agent_states)

    return device_specific_inference_fn

  # Distribute the inference calls among the inference cores.
  branch_index = tf.cast(
      inference_iteration.assign_add(1) % len(inference_devices), tf.int32)
  agent_outputs, curr_agent_states = tf.switch_case(branch_index, {
      i: make_inference_fn(inference_device)
      for i, inference_device in enumerate(inference_devices)
  })
  agent_outputs = agent_outputs._replace(action=apply_epsilon_greedy(
      agent_outputs.action, env_ids, get_num_training_envs(),
      FLAGS.num_eval_envs, FLAGS.eval_epsilon, num_actions))

  # Append the latest outputs to the unroll, only for experience coming from
  # training environments (IDs < num_training_envs), and insert completed
  # unrolls in the queue.
  # <int64>[num_training_envs]
  training_indices = tf.where(is_training_env(env_ids))[:, 0]
  training_env_ids = tf.gather(env_ids, training_indices)
  training_prev_actions, training_env_outputs, training_agent_outputs = (
      tf.nest.map_structure(lambda s: tf.gather(s, training_indices),
                            (prev_actions, env_outputs, agent_outputs)))
  append_to_store = (
      training_prev_actions, training_env_outputs, training_agent_outputs)
  completed_ids, completed_unrolls = store.append(
      training_env_ids, append_to_store)
  _, unrolled_env_outputs, unrolled_agent_outputs = completed_unrolls
  unrolled_agent_states = first_agent_states.read(completed_ids)

  # Only use the suffix of the unrolls that is actually used for training.
  # The prefix is only used for burn-in of the agent state at training time.
  _, agent_outputs_suffix = utils.split_structure(
      utils.make_time_major(unrolled_agent_outputs), FLAGS.burn_in)
  _, env_outputs_suffix = utils.split_structure(
      utils.make_time_major(unrolled_env_outputs), FLAGS.burn_in)
  _, initial_priorities = compute_loss_and_priorities_from_agent_outputs(
      # We don't use the outputs from a separate target network for computing
      # initial priorities.
      agent_outputs_suffix,
      agent_outputs_suffix,
      env_outputs_suffix,
      agent_outputs_suffix,
      gamma=FLAGS.discounting)

  unrolls = Unroll(unrolled_agent_states, initial_priorities,
                   *completed_unrolls)
  unroll_queue.enqueue_many(unrolls)
  first_agent_states.replace(completed_ids, agent_states.read(completed_ids))

  # Update current state.
  agent_states.replace(env_ids, curr_agent_states)
  actions.replace(env_ids, agent_outputs.action)

  # Return environment actions to environments.
  return agent_outputs.action

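# `utils.split_structure(nest, n)` above is used to cut off the burn-in
# prefix. A minimal sketch of the behavior implied by that call site, namely
# splitting every tensor in the nest along the leading (time) axis into a
# length-n prefix and the remaining suffix; this is an assumption about the
# real helper, for illustration only:
def _split_structure_sketch(nest, prefix_length):
  flat = tf.nest.flatten(nest)
  prefixes = [t[:prefix_length] for t in flat]
  suffixes = [t[prefix_length:] for t in flat]
  return (tf.nest.pack_sequence_as(nest, prefixes),
          tf.nest.pack_sequence_as(nest, suffixes))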