Example #1
    def dequeue(ctx):
        """Inserts into and samples from the replay buffer.

    Args:
      ctx: tf.distribute.InputContext.

    Returns:
      A flattened `SampledUnrolls` structures where per-timestep tensors have
      front dimensions [unroll_length, batch_size_per_replica].
    """
        per_replica_batch_size = ctx.get_per_replica_batch_size(batch_size)
        insertion_batch_size = get_replay_insertion_batch_size(
            per_replica=True)

        print_every = tf.cast(
            insertion_batch_size *
            (1 + FLAGS.replay_buffer_min_size // 50 // insertion_batch_size),
            tf.int64)
        log_summary_every = tf.cast(insertion_batch_size * 500, tf.int64)

        while tf.constant(True):
            # Each tensor in 'unrolls' has shape [insertion_batch_size, unroll_length,
            # <field-specific dimensions>].
            unrolls = unroll_queue.dequeue_many(insertion_batch_size)
            # The replay buffer is not threadsafe (and making it thread-safe might
            # slow it down), which is why we insert and sample in a single thread, and
            # use TF Queues for passing data between threads.
            replay_buffer.insert(unrolls, unrolls.priority)
            if tf.equal(replay_buffer.num_inserted % log_summary_every, 0):
                # Unfortunately, there is no tf.summary(log_every_n_sec).
                tf.summary.histogram('initial_priorities', unrolls.priority)
            if replay_buffer.num_inserted >= FLAGS.replay_buffer_min_size:
                break

            if tf.equal(replay_buffer.num_inserted % print_every, 0):
                tf.print(
                    'Waiting for the replay buffer to fill up. '
                    'It currently has', replay_buffer.num_inserted,
                    'elements, waiting for at least',
                    FLAGS.replay_buffer_min_size, 'elements')

        sampled_indices, weights, sampled_unrolls = replay_buffer.sample(
            per_replica_batch_size, priority_exponent)
        sampled_unrolls = sampled_unrolls._replace(
            prev_actions=utils.make_time_major(sampled_unrolls.prev_actions),
            env_outputs=utils.make_time_major(sampled_unrolls.env_outputs),
            agent_outputs=utils.make_time_major(sampled_unrolls.agent_outputs))
        sampled_unrolls = sampled_unrolls._replace(
            env_outputs=encode(sampled_unrolls.env_outputs))
        # tf.data.Dataset treats list leafs as tensors, so we need to flatten and
        # repack.
        return tf.nest.flatten(
            SampledUnrolls(sampled_unrolls, sampled_indices, weights))
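
The "flatten and repack" comment above has a matching step on the consumer side of the tf.data pipeline: once the flattened tensors come out of the dataset, they are packed back into the nested structure before training uses them. A minimal sketch of that repacking, assuming a hypothetical structure_template nest with the same layout as the SampledUnrolls returned above (this is not code from the original file):

import tensorflow as tf

def repack_fn(structure_template):
  # Undo the tf.nest.flatten() done inside dequeue(): rebuild the nested
  # SampledUnrolls layout from the flat list of tensors produced by the
  # dataset. structure_template is any nest (e.g. a namedtuple of
  # tf.TensorSpec) with the same structure as SampledUnrolls.
  def repack(*flat_tensors):
    return tf.nest.pack_sequence_as(structure_template, list(flat_tensors))
  return repack

# Typical usage: dataset = dataset.map(repack_fn(structure_template))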
Example #2
 def dequeue(ctx):
   # Create batch (time major).
   actor_outputs = tf.nest.map_structure(lambda *args: tf.stack(args), *[
       unroll_queue.dequeue()
       for i in range(ctx.get_per_replica_batch_size(FLAGS.batch_size))
   ])
   actor_outputs = actor_outputs._replace(
       prev_actions=utils.make_time_major(actor_outputs.prev_actions),
       env_outputs=utils.make_time_major(actor_outputs.env_outputs),
       agent_outputs=utils.make_time_major(actor_outputs.agent_outputs))
   actor_outputs = actor_outputs._replace(
       env_outputs=encode(actor_outputs.env_outputs))
   # tf.data.Dataset treats list leafs as tensors, so we need to flatten and
   # repack.
   return tf.nest.flatten(actor_outputs)
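
Both dequeue(ctx) variants take a tf.distribute.InputContext, which suggests they are consumed through a distributed tf.data pipeline. A plausible wiring, sketched under the assumption of a tf.distribute.Strategy object named strategy (not shown in the original snippets):

import tensorflow as tf

def dataset_fn(ctx):
  # An endless dummy dataset; each map() call runs dequeue(ctx) once and
  # yields the flattened, time-major batch it returns.
  dummy = tf.data.Dataset.from_tensors(0).repeat(None)
  return dummy.map(lambda _: dequeue(ctx))

# Sketch only: strategy is an assumed tf.distribute.Strategy instance.
# dataset = strategy.distribute_datasets_from_function(dataset_fn)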
Example #3
    def dequeue(ctx):
        """Inserts into and samples from the replay buffer.

    Args:
      ctx: tf.distribute.InputContext.

    Returns:
      A flattened `Unroll` structures where per-timestep tensors have
      front dimensions [unroll_length, batch_size_per_replica].
    """
        per_replica_batch_size = ctx.get_per_replica_batch_size(batch_size)

        while tf.constant(True):
            # Each tensor in 'unrolls' has shape [insertion_batch_size, unroll_length,
            # <field-specific dimensions>].
            insert_batch_size = get_replay_insertion_batch_size(
                per_replica=True)
            unrolls = unroll_queue.dequeue_many(insert_batch_size)

            # The replay buffer is not threadsafe (and making it thread-safe might
            # slow it down), which is why we insert and sample in a single thread, and
            # use TF Queues for passing data between threads.
            replay_buffer.insert(unrolls,
                                 priorities=tf.ones(insert_batch_size))

            if replay_buffer.num_inserted >= FLAGS.batch_size:
                break

        _, _, sampled_unrolls = replay_buffer.sample(per_replica_batch_size,
                                                     priority_exp=0.)
        sampled_unrolls = sampled_unrolls._replace(
            prev_actions=utils.make_time_major(sampled_unrolls.prev_actions),
            env_outputs=utils.make_time_major(sampled_unrolls.env_outputs),
            agent_actions=utils.make_time_major(sampled_unrolls.agent_actions))
        sampled_unrolls = sampled_unrolls._replace(
            env_outputs=encode(sampled_unrolls.env_outputs))
        # tf.data.Dataset treats list leafs as tensors, so we need to flatten and
        # repack.
        return tf.nest.flatten(sampled_unrolls)
Example #4
 def test_nest(self):
     x = (tf.constant([[1, 2], [3, 4]]), tf.constant([[1], [2]]))
     a, b = utils.make_time_major(x)
     self.assertAllEqual(a, tf.constant([[1, 3], [2, 4]]))
     self.assertAllEqual(b, tf.constant([[1, 2]]))
Example #5
 def test_uint16(self):
     x = tf.constant([[1, 2], [3, 4]], tf.uint16)
     self.assertAllEqual(utils.make_time_major(x),
                         tf.constant([[1, 3], [2, 4]]))
Example #6
 def test_dynamic(self):
     x, = tf.py_function(
         lambda: np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], np.int32), [],
         [tf.int32])
     self.assertAllEqual(utils.make_time_major(x),
                         tf.constant([[[1, 2], [5, 6]], [[3, 4], [7, 8]]]))
Example #7
 def test_static(self):
     x = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
     self.assertAllEqual(utils.make_time_major(x),
                         tf.constant([[[1, 2], [5, 6]], [[3, 4], [7, 8]]]))
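
Examples #4 through #7 pin down the behaviour of utils.make_time_major: it swaps the leading batch and time dimensions of every tensor in a nest, including tensors whose shape is only known at runtime. A minimal sketch that satisfies those four tests (not the actual seed_rl implementation, which may need extra handling, e.g. for uint16 on accelerators):

import tensorflow as tf

def make_time_major_sketch(nest):
  # Swap the two leading dimensions of every tensor in the nest.
  # Assumes every leaf has rank >= 2, as in all of the tests above.
  def swap(t):
    t = tf.convert_to_tensor(t)
    # Permutation [1, 0, 2, 3, ...]; tf.rank() also covers dynamic shapes.
    perm = tf.concat([[1, 0], tf.range(2, tf.rank(t))], axis=0)
    return tf.transpose(t, perm)
  return tf.nest.map_structure(swap, nest)

# make_time_major_sketch(tf.constant([[1, 2], [3, 4]]))  # -> [[1, 3], [2, 4]]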
Example #8
    def inference(env_ids, run_ids, env_outputs, raw_rewards):
        """Agent inference.

    This evaluates the agent policy on the provided environment data (reward,
    done, observation), and store appropriate data to feed the main training
    loop.

    Args:
      env_ids: <int32>[inference_batch_size], the environment task IDs (in range
        [0, num_tasks)).
      run_ids: <int64>[inference_batch_size], the environment run IDs.
        Environment generates a random int64 run id at startup, so this can be
        used to detect the environment jobs that restarted.
      env_outputs: Follows env_output_specs, but with the inference_batch_size
        added as first dimension. These are the actual environment outputs
        (reward, done, observation).
      raw_rewards: <float32>[inference_batch_size], representing the raw reward
        of each step.

    Returns:
      A tensor <int32>[inference_batch_size] with one action for each
        environment.
    """
        # Reset the environments that had their first run or crashed.
        previous_run_ids = env_run_ids.read(env_ids)
        env_run_ids.replace(env_ids, run_ids)
        reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
        envs_needing_reset = tf.gather(env_ids, reset_indices)
        if tf.not_equal(tf.shape(envs_needing_reset)[0], 0):
            tf.print('Environments needing reset:', envs_needing_reset)
        env_infos.reset(envs_needing_reset)
        store.reset(
            tf.gather(envs_needing_reset,
                      tf.where(is_training_env(envs_needing_reset))[:, 0]))
        initial_agent_states = agent.initial_state(
            tf.shape(envs_needing_reset)[0])
        first_agent_states.replace(envs_needing_reset, initial_agent_states)
        agent_states.replace(envs_needing_reset, initial_agent_states)
        actions.reset(envs_needing_reset)

        tf.debugging.assert_non_positive(
            tf.cast(env_outputs.abandoned, tf.int32),
            'Abandoned done states are not supported in R2D2.')

        # Update steps and return.
        env_infos.add(env_ids, (0, env_outputs.reward, raw_rewards))
        done_ids = tf.gather(env_ids, tf.where(env_outputs.done)[:, 0])
        done_episodes_info = env_infos.read(done_ids)
        info_queue.enqueue_many(
            EpisodeInfo(*(done_episodes_info + (done_ids, ))))
        env_infos.reset(done_ids)
        env_infos.add(env_ids, (FLAGS.num_action_repeats, 0., 0.))

        # Inference.
        prev_actions = actions.read(env_ids)
        input_ = encode((prev_actions, env_outputs))
        prev_agent_states = agent_states.read(env_ids)

        def make_inference_fn(inference_device):
            def device_specific_inference_fn():
                with tf.device(inference_device):

                    @tf.function
                    def agent_inference(*args):
                        return agent(*decode(args))

                    return agent_inference(input_, prev_agent_states)

            return device_specific_inference_fn

        # Distribute the inference calls among the inference cores.
        branch_index = tf.cast(
            inference_iteration.assign_add(1) % len(inference_devices),
            tf.int32)
        agent_outputs, curr_agent_states = tf.switch_case(
            branch_index, {
                i: make_inference_fn(inference_device)
                for i, inference_device in enumerate(inference_devices)
            })

        agent_outputs = agent_outputs._replace(action=apply_epsilon_greedy(
            agent_outputs.action, env_ids, get_num_training_envs(),
            FLAGS.num_eval_envs, FLAGS.eval_epsilon, num_actions))

        # Append the latest outputs to the unroll, only for experience coming from
        # training environments (IDs < num_training_envs), and insert completed
        # unrolls in queue.
        # <int64>[num_training_envs]
        training_indices = tf.where(is_training_env(env_ids))[:, 0]
        training_env_ids = tf.gather(env_ids, training_indices)
        training_prev_actions, training_env_outputs, training_agent_outputs = (
            tf.nest.map_structure(lambda s: tf.gather(s, training_indices),
                                  (prev_actions, env_outputs, agent_outputs)))

        append_to_store = (training_prev_actions, training_env_outputs,
                           training_agent_outputs)
        completed_ids, completed_unrolls = store.append(
            training_env_ids, append_to_store)
        _, unrolled_env_outputs, unrolled_agent_outputs = completed_unrolls
        unrolled_agent_states = first_agent_states.read(completed_ids)

        # Only use the suffix of the unrolls that is actually used for training. The
        # prefix is only used for burn-in of agent state at training time.
        _, agent_outputs_suffix = utils.split_structure(
            utils.make_time_major(unrolled_agent_outputs), FLAGS.burn_in)
        _, env_outputs_suffix = utils.split_structure(
            utils.make_time_major(unrolled_env_outputs), FLAGS.burn_in)
        _, initial_priorities = compute_loss_and_priorities_from_agent_outputs(
            # We don't use the outputs from a separated target network for computing
            # initial priorities.
            agent_outputs_suffix,
            agent_outputs_suffix,
            env_outputs_suffix,
            agent_outputs_suffix,
            gamma=FLAGS.discounting)

        unrolls = Unroll(unrolled_agent_states, initial_priorities,
                         *completed_unrolls)
        unroll_queue.enqueue_many(unrolls)
        first_agent_states.replace(completed_ids,
                                   agent_states.read(completed_ids))

        # Update current state.
        agent_states.replace(env_ids, curr_agent_states)
        actions.replace(env_ids, agent_outputs.action)

        # Return environment actions to environments.
        return agent_outputs.action
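
Two helpers used by inference() are only referenced here, not defined. Based on the comment above that training environments occupy the IDs below num_training_envs, and on the standard definition of epsilon-greedy exploration, plausible sketches look like the following; the exact per-environment epsilon schedule used for training actors is an assumption and is not reproduced:

import tensorflow as tf

def is_training_env_sketch(env_ids, num_training_envs):
  # Training environments occupy IDs [0, num_training_envs); the remaining
  # FLAGS.num_eval_envs IDs are evaluation environments.
  return env_ids < num_training_envs

def epsilon_greedy_sketch(greedy_actions, epsilon, num_actions):
  # Plain epsilon-greedy: with probability epsilon, replace the greedy action
  # (assumed int32, per the docstring above) by a uniformly random one. The
  # real apply_epsilon_greedy also chooses epsilon per environment (e.g.
  # FLAGS.eval_epsilon for evaluation environments), which is not modelled here.
  batch_size = tf.shape(greedy_actions)[0]
  random_actions = tf.random.uniform(
      [batch_size], maxval=num_actions, dtype=tf.int32)
  explore = tf.random.uniform([batch_size]) < epsilon
  return tf.where(explore, random_actions, greedy_actions)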