    def preprocess_experience(self, experience: Experience):
        """Fill experience.rollout_info with PredictiveRepresentationLearnerInfo

        Note that the shape of experience is [B, T, ...].

        The target is a Tensor (or a nest of Tensors) when there is only one
        decoder. When there are multiple decoders, the target is a list, and
        each of its elements is a Tensor (or a nest of Tensors) used as the
        target for the corresponding decoder.

        """
        assert experience.batch_info != ()
        batch_info: BatchInfo = experience.batch_info
        replay_buffer: ReplayBuffer = experience.replay_buffer
        mini_batch_length = experience.step_type.shape[1]

        with alf.device(replay_buffer.device):
            # [B, 1]
            positions = convert_device(batch_info.positions).unsqueeze(-1)
            # [B, 1]
            env_ids = convert_device(batch_info.env_ids).unsqueeze(-1)

            # [B, T]
            positions = positions + torch.arange(mini_batch_length)

            # [B, T]
            steps_to_episode_end = replay_buffer.steps_to_episode_end(
                positions, env_ids)
            # [B, T]
            episode_end_positions = positions + steps_to_episode_end

            # [B, T, unroll_steps+1]
            positions = positions.unsqueeze(-1) + torch.arange(
                self._num_unroll_steps + 1)
            # [B, 1, 1]
            env_ids = env_ids.unsqueeze(-1)
            # [B, T, 1]
            episode_end_positions = episode_end_positions.unsqueeze(-1)

            # [B, T, unroll_steps+1]
            mask = positions <= episode_end_positions

            # [B, T, unroll_steps+1]
            positions = torch.min(positions, episode_end_positions)

            # [B, T, unroll_steps+1, ...]
            target = replay_buffer.get_field(self._target_fields, env_ids,
                                             positions)

            # [B, T, unroll_steps+1]
            action = replay_buffer.get_field('action', env_ids, positions)

            rollout_info = PredictiveRepresentationLearnerInfo(action=action,
                                                               mask=mask,
                                                               target=target)

        rollout_info = convert_device(rollout_info)

        return experience._replace(rollout_info=rollout_info)
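
A minimal standalone sketch of the position bookkeeping above, using plain torch with made-up sizes (B, T, unroll_steps, the start positions and the episode ends are all hypothetical); it only shows how positions are expanded to [B, T, unroll_steps+1], masked, and clamped at the episode end:

import torch

B, T, unroll_steps = 2, 3, 4                      # hypothetical mini-batch sizes
start = torch.tensor([[5], [20]])                 # [B, 1] sampled start positions
positions = start + torch.arange(T)               # [B, T] position of each step
episode_end = torch.tensor([[9], [100]])          # [B, 1] hypothetical episode ends

# [B, T, unroll_steps+1]: every step unrolled unroll_steps into the future
positions = positions.unsqueeze(-1) + torch.arange(unroll_steps + 1)
episode_end = episode_end.unsqueeze(-1)           # [B, 1, 1] for broadcasting

mask = positions <= episode_end                   # False for steps past the episode end
positions = torch.min(positions, episode_end)     # clamp so lookups stay inside the episode

assert positions.shape == (B, T, unroll_steps + 1)
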
Example #2
    def find_sum_bound(self, thresholds):
        """
        The result is an int64 Tensor with the same shape as `thresholds`.
        result[i] is the minimum idx such that
            thresholds[i] < values[0] + ... + values[idx]

        values[result[i]] will never be 0.

        Args:
            thresholds (Tensor): 1-D Tensor. None of its elements should be
                greater than ``summary()``.
        Returns:
            Tensor: 1-D int64 Tensor with the same shape as ``thresholds``.
                Note that if thresholds[i] == summary(), result[i] will be
                the index of the non-zero value with the largest index.
        Raises:
            ValueError: if one or more elements of ``thresholds`` are greater
                than ``summary()``.
        """

        def _step(indices, thresholds):
            """Choose one of the children of each index based on threshold.

            If the threshold is greater than or equal to the left child (and
            the right child is non-zero), choose the right child and update the
            threshold to threshold - left_child. Otherwise choose the left
            child and keep the threshold unchanged.
            """
            indices *= 2
            left = self._values[indices]
            right = self._values[indices + 1]
            # The condition (thresholds >= left) * (right == 0) is only possible
            # if the original threshold == summary(), we want to make sure we
            # still get an index corresponding to non-zero value.
            greater = (thresholds >= left) * (right > 0)
            indices = torch.where(greater, indices + 1, indices)
            thresholds = torch.where(greater, thresholds - left, thresholds)
            return indices, thresholds

        with alf.device(self._device):
            if not torch.all(thresholds <= self.summary()):
                raise ValueError("thresholds cannot "
                                 "be greater than summary(): got %s vs. %s" %
                                 (thresholds.max(), self.summary()))
            thresholds = convert_device(thresholds)
            indices = torch.ones_like(thresholds, dtype=torch.int64)
            for _ in range(self._depth):
                indices, thresholds = _step(indices, thresholds)

            is_small = indices < self._capacity
            num_small = is_small.to(torch.int64).sum()
            if num_small > 0:
                i = torch.where(is_small)[0]
                small_indices = indices[i]
                small_thresholds = thresholds[i]
                small_indices, _ = _step(small_indices, small_thresholds)
                indices[i] = small_indices

        return convert_device(self._leaf_to_index(indices))
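
The descent above is easier to follow on an explicit tree. Below is a self-contained toy version (plain torch, power-of-two capacity, sum op, not ALF's SegmentTree) that stores the root at index 1 and leaf i at index capacity + i, then walks down choosing the right child whenever the threshold is at least the left child's sum and the right subtree is non-zero:

import torch

capacity = 8                                           # power of two for simplicity
values = torch.tensor([3., 0., 2., 5., 0., 1., 4., 0.])

# tree[1] is the root; leaf i lives at tree[capacity + i]
tree = torch.zeros(2 * capacity)
tree[capacity:] = values
for i in range(capacity - 1, 0, -1):
    tree[i] = tree[2 * i] + tree[2 * i + 1]

def find_sum_bound(thresholds):
    """result[i] is the smallest idx with thresholds[i] < values[0] + ... + values[idx]."""
    indices = torch.ones_like(thresholds, dtype=torch.int64)
    while indices[0] < capacity:                       # descend log2(capacity) levels
        left = tree[2 * indices]
        right = tree[2 * indices + 1]
        go_right = (thresholds >= left) & (right > 0)  # stay on non-zero leaves at the boundary
        indices = torch.where(go_right, 2 * indices + 1, 2 * indices)
        thresholds = torch.where(go_right, thresholds - left, thresholds)
    return indices - capacity

print(find_sum_bound(torch.tensor([0., 2.9, 3., 14.9, 15.])))  # tensor([0, 0, 2, 6, 6])
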
Example #4
    def __getitem__(self, idx):
        """Get the values of leaves.

        Args:
            idx (Tensor): 1-D int64 Tensor. Its values should be in range
                [0, capacity).
        Returns:
            Tensor: with the same shape as ``idx``.
        """
        with alf.device(self._device):
            idx = convert_device(idx)
            assert 0 <= idx.min()
            assert idx.max() < self._capacity
            result = self._values[self._index_to_leaf(idx)]
        return convert_device(result)
Example #5
    def _dequeue(self, env_ids=None, n=1):
        """Return earliest ``n`` steps and mark them removed in the buffer.

        Args:
            env_ids (Tensor): If None, ``batch_size`` must be num_environments.
                If not None, dequeue from these environments. We assume there
                are no duplicate ids in ``env_ids``. ``result[i]`` will be from
                environment ``env_ids[i]``.
            n (int): Number of steps to dequeue.
        Returns:
            nested Tensors of shape ``[batch_size, n, ...]``.
        Raises:
            AssertionError: when not enough data is present.
        """
        with alf.device(self._device):
            env_ids = self.check_convert_env_ids(env_ids)
            current_size = self._current_size[env_ids]
            min_size = current_size.min()
            assert min_size >= n, (
                "Not all environments have enough data. The smallest data "
                "size is: %s Try storing more data before calling dequeue" %
                min_size)
            batch_size = env_ids.shape[0]
            pos = self._current_pos[env_ids] - current_size  # mod done later
            b_indices = env_ids.reshape(batch_size, 1).expand(-1, n)
            t_range = torch.arange(n).reshape(1, -1)
            t_indices = self.circular(pos.reshape(batch_size, 1) + t_range)
            result = alf.nest.map_structure(
                lambda b: b[(b_indices, t_indices)], self._buffer)
            self._current_size[env_ids] = current_size - n
            # set flags if they exist to unblock potential consumers
            if self._dequeued:
                self._dequeued.set()
                self._enqueued.clear()
        return convert_device(result)
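
A small standalone illustration (plain torch, hypothetical buffer contents and positions) of the advanced-indexing pattern used above: per-environment start positions are wrapped modulo the buffer length and paired with environment ids to gather [batch_size, n, ...] slices in a single indexing operation:

import torch

num_envs, max_length, n = 3, 5, 2
# hypothetical ring buffer of scalars: buffer[i, j] is slot j of environment i
buffer = torch.arange(num_envs * max_length).reshape(num_envs, max_length)

env_ids = torch.tensor([0, 2])                   # dequeue from these environments
pos = torch.tensor([9, 3])                       # earliest positions, not yet wrapped
batch_size = env_ids.shape[0]

b_indices = env_ids.reshape(batch_size, 1).expand(-1, n)                 # [batch_size, n]
t_indices = (pos.reshape(batch_size, 1) + torch.arange(n)) % max_length  # wrap into the ring

result = buffer[(b_indices, t_indices)]          # [batch_size, n], gathered in one indexing op
print(result)                                    # tensor([[ 4,  0], [13, 14]])
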
Example #6
    def gather_all(self):
        """Returns all the items in the buffer.

        Returns:
            Tensors of shape [B, T, ...], B=num_environments, T=current_size
        Raises:
            AssertionError: if ``current_size`` is not the same for all the
                environments.
        """
        size = self._current_size.min()
        max_size = self._current_size.max()
        assert size == max_size, (
            "Not all environments have the same size. min_size: %s "
            "max_size: %s" % (size, max_size))
        if size < self._max_length:
            pos = self._current_pos.min()
            max_pos = self._current_pos.max()
            assert pos == max_pos, (
                "Not all environments have the same ending position. "
                "min_pos: %s max_pos: %s" % (pos, max_pos))
            assert size == pos, (
                "When the buffer is not full, the ending position of the data "
                "in the buffer (current_pos) should coincide with current_size")

        # NOTE: this is not the proper way to gather all from a ring
        # buffer whose data can start from the middle, so this is limited
        # to the case where clear() is the only way to remove data from
        # the buffer.
        if size == self._max_length:
            result = self._buffer
        else:
            # Assumes that non-full buffer always stores data starting from 0
            result = alf.nest.map_structure(lambda buf: buf[:, :size, ...],
                                            self._buffer)
        return convert_device(result)
Example #7
    def __setitem__(self, indices, values):
        """Set the value of leaves and update the internal nodes.

        Args:
            indices (Tensor): 1-D int64 Tensor. Its values should be in range
                [0, capacity).
            values (Tensor): 1-D Tensor with the same shape as ``indices``
        """

        def _step(indices):
            """
            Calculate the parent value from its children.
            """
            indices = indices // 2
            indices = torch.unique(indices)
            left = self._values[indices * 2]
            right = self._values[indices * 2 + 1]
            self._values[indices] = op(left, right)
            return indices

        with alf.device(self._device):
            indices = convert_device(indices)
            values = convert_device(values)

            assert indices.ndim == 1
            assert values.ndim == 1
            assert indices.shape == values.shape, (
                "indices and values should be 1-D tensor with the same length. "
                "Got %s and %s." % (indices.shape, values.shape))
            op = self._op
            indices, order = torch.sort(indices)
            values = values[order]
            assert indices[-1] < self._capacity
            indices = self._index_to_leaf(indices)
            self._values[indices] = values

            num_large = (indices >= self._leftmost_leaf).to(torch.int64).sum()
            if num_large > 0:
                large_indices = indices[:num_large]
                small_indices = indices[num_large:]
                large_indices = _step(large_indices)
                indices = torch.cat([large_indices, small_indices])

            for _ in range(self._depth):
                indices = _step(indices)
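
For reference, a stripped-down standalone version of the bottom-up update above (plain torch, sum op, power-of-two capacity, simple layout with leaf i at index capacity + i; a sketch of the idea, not ALF's SegmentTree): after writing the leaves, each round replaces the touched indices by their deduplicated parents and recomputes each parent from its two children.

import torch

capacity = 8
tree = torch.zeros(2 * capacity)               # tree[1] is the root (the sum of all leaves)

def set_leaves(indices, values):
    indices = indices + capacity               # leaf i is stored at tree[capacity + i]
    tree[indices] = values
    for _ in range(capacity.bit_length() - 1): # propagate towards the root, one level per round
        indices = torch.unique(indices // 2)   # parents of all touched nodes, deduplicated
        tree[indices] = tree[2 * indices] + tree[2 * indices + 1]

set_leaves(torch.tensor([0, 3, 6]), torch.tensor([3., 5., 4.]))
print(tree[1])                                 # tensor(12.) -- the summary (sum of all leaves)
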
Example #8
    def initial_priority(self):
        """The initial priority used for newly added experiences.

        We use a large value for initial priority so that a new experience can
        be used for training sooner. We make it at least 1.0 so that it can never
        be very small.
        """
        return convert_device(
            torch.max(self._max_tree.summary(), self._initial_priority))
Example #9
    def get_batch_by_indices(self, indices):
        r"""Get the samples by indices

        index=0 corresponds to the earliest added sample in the DataBuffer.

        Args:
            indices (Tensor): indices of the samples

        Returns:
            Tensor:
            Tensor of shape ``[batch_size] + tensor_spec.shape``, where
            ``batch_size`` is ``indices.shape[0]``
        """
        with alf.device(self._device):
            indices = convert_device(indices)
            indices = self.circular(indices + self.current_pos -
                                    self.current_size)
            result = alf.nest.map_structure(lambda buf: buf[indices],
                                            self._derived_buffer)
        return convert_device(result)
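
The index arithmetic above maps a logical index (0 = earliest sample still stored) to a physical slot in the ring; a toy version with plain torch and hypothetical sizes:

import torch

capacity, current_pos, current_size = 4, 6, 4   # 6 items ever added, only the last 4 kept
storage = torch.tensor([4, 5, 2, 3])            # physical ring: items 4, 5 overwrote items 0, 1

logical = torch.tensor([0, 1, 2, 3])            # 0 is the earliest sample still in the buffer
physical = (logical + current_pos - current_size) % capacity
print(storage[physical])                        # tensor([2, 3, 4, 5]) -- earliest to latest
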
Example #10
    def summary(self):
        """The summary of the tree.

        If ``op`` is ``torch.add``, it's the sum of all values.
        If ``op`` is ``torch.min``, it's the min of all values.
        If ``op`` is ``torch.max``, it's the max of all values.

        Returns:
            a scalar
        """
        return convert_device(self._values[1])
Example #11
    def check_convert_env_ids(self, env_ids):
        """Convert ``env_ids`` to a 1-D int64 tensor on the buffer device.

        ``None`` means all environments.
        """
        with alf.device(self._device):
            if env_ids is None:
                env_ids = torch.arange(self._num_envs)
            else:
                env_ids = env_ids.to(torch.int64)
            env_ids = convert_device(env_ids)
            assert len(env_ids.shape
                       ) == 1, "env_ids {}, should be a 1D tensor".format(
                           env_ids.shape)
            return env_ids
Example #12
    def get_batch(self, batch_size):
        r"""Get batsh_size random samples in the buffer.

        Args:
            batch_size (int): batch size
        Returns:
            Tensor of shape ``[batch_size] + tensor_spec.shape``
        """
        with alf.device(self._device):
            indices = torch.randint(low=0,
                                    high=self.current_size,
                                    size=(batch_size, ),
                                    dtype=torch.int64)
            result = self.get_batch_by_indices(indices)
        return convert_device(result)
Example #13
    def _prepare_reanalyze_data(self, replay_buffer: ReplayBuffer, env_ids,
                                positions, n1, n2):
        """
        Get the n1 + n2 steps of experience indicated by ``positions`` and return
        as the first n1 as ``exp1`` and the next n2 steps as ``exp2``.
        """
        batch_size = env_ids.shape[0]
        n = n1 + n2
        flat_env_ids = env_ids.expand_as(positions).reshape(-1)
        flat_positions = positions.reshape(-1)
        exp = replay_buffer.get_field(None, flat_env_ids, flat_positions)

        if self._data_transformer_ctor is not None:
            if self._data_transformer is None:
                observation_spec = dist_utils.extract_spec(exp.observation)
                self._data_transformer = create_data_transformer(
                    self._data_transformer_ctor, observation_spec)

            # DataTransformer assumes the shape of exp is [B, T, ...]
            # It also needs exp.batch_info and exp.replay_buffer.
            exp = alf.nest.map_structure(lambda x: x.unsqueeze(1), exp)
            exp = exp._replace(batch_info=BatchInfo(flat_env_ids,
                                                    flat_positions),
                               replay_buffer=replay_buffer)
            exp = self._data_transformer.transform_experience(exp)
            exp = exp._replace(batch_info=(), replay_buffer=())
            exp = alf.nest.map_structure(lambda x: x.squeeze(1), exp)

        def _split1(x):
            shape = x.shape[1:]
            x = x.reshape(batch_size, n, *shape)
            return x[:, :n1, ...].reshape(batch_size * n1, *shape)

        def _split2(x):
            shape = x.shape[1:]
            x = x.reshape(batch_size, n, *shape)
            return x[:, n1:, ...].reshape(batch_size * n2, *shape)

        with alf.device(self._device):
            exp = convert_device(exp)
            exp1 = alf.nest.map_structure(_split1, exp)
            exp2 = alf.nest.map_structure(_split2, exp)

        return exp1, exp2
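
The two _split helpers above simply undo the flattening: a [batch_size * (n1 + n2), ...] tensor is viewed as [batch_size, n1 + n2, ...], sliced along the time axis, and flattened back. A minimal sketch with plain torch and made-up sizes:

import torch

batch_size, n1, n2 = 2, 3, 2
n = n1 + n2
x = torch.arange(batch_size * n * 4).reshape(batch_size * n, 4)  # flat [B*n, feature]

def split(x, start, length):
    shape = x.shape[1:]
    x = x.reshape(batch_size, n, *shape)                  # recover the [B, n, ...] layout
    return x[:, start:start + length, ...].reshape(batch_size * length, *shape)

exp1 = split(x, 0, n1)        # first n1 steps of every sequence, shape [B*n1, 4]
exp2 = split(x, n1, n2)       # remaining n2 steps, shape [B*n2, 4]
assert exp1.shape == (batch_size * n1, 4) and exp2.shape == (batch_size * n2, 4)
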
Example #14
    def _reanalyze(self, replay_buffer: ReplayBuffer, env_ids, positions,
                   mcts_state_field):
        batch_size = env_ids.shape[0]
        mini_batch_size = batch_size
        if self._reanalyze_batch_size is not None:
            mini_batch_size = self._reanalyze_batch_size

        result = []
        for i in range(0, batch_size, mini_batch_size):
            # Divide into several batches so that memory is enough.
            result.append(
                self._reanalyze1(replay_buffer, env_ids[i:i + mini_batch_size],
                                 positions[i:i + mini_batch_size],
                                 mcts_state_field))

        if len(result) == 1:
            result = result[0]
        else:
            result = alf.nest.map_structure(
                lambda *tensors: torch.cat(tensors), *result)
        return convert_device(result)
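
The loop above processes the batch in chunks and then concatenates the per-chunk results field by field. A plain-torch sketch of the same pattern on tuples (standing in for the nests), with a hypothetical per-chunk function:

import torch

def process(chunk):                          # hypothetical per-chunk computation
    return chunk * 2, chunk + 1              # returns a tuple of tensors

data = torch.arange(10)
mini_batch_size = 4
results = [process(data[i:i + mini_batch_size]) for i in range(0, len(data), mini_batch_size)]

# concatenate field-wise, like alf.nest.map_structure(lambda *t: torch.cat(t), *results)
merged = tuple(torch.cat(field) for field in zip(*results))
assert torch.equal(merged[0], data * 2) and torch.equal(merged[1], data + 1)
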
Example #15
    def _enqueue(self, batch, env_ids=None):
        """Add a batch of items to the buffer (atomic).

        Args:
            batch (Tensor): shape should be
                ``[batch_size] + tensor_spec.shape``.
            env_ids (Tensor): If ``None``, ``batch_size`` must be
                ``num_environments``. If not ``None``, its shape should be
                ``[batch_size]``. We assume there are no duplicate ids in
                ``env_id``. ``batch[i]`` is generated by environment
                ``env_ids[i]``.
        """
        batch_size = alf.nest.get_nest_batch_size(batch)
        with alf.device(self._device):
            batch = convert_device(batch)
            env_ids = self.check_convert_env_ids(env_ids)
            assert batch_size == env_ids.shape[0], (
                "batch and env_ids do not have same length %s vs. %s" %
                (batch_size, env_ids.shape[0]))

            # Make sure that there are no duplicates in `env_ids`.
            # torch.unique(env_ids, return_counts=True)[1] is the count of each unique item
            assert torch.unique(
                env_ids, return_counts=True)[1].max() == 1, (
                    "There are duplicated ids in env_ids %s" % env_ids)

            current_pos = self._current_pos[env_ids]
            indices = env_ids * self._max_length + self.circular(current_pos)
            alf.nest.map_structure(
                lambda buf, bat: buf.__setitem__(indices, bat.detach()),
                self._flattened_buffer, batch)

            self._current_pos[env_ids] = current_pos + 1
            current_size = self._current_size[env_ids]
            self._current_size[env_ids] = torch.clamp(
                current_size + 1, max=self._max_length)
            # set flags if they exist to unblock potential consumers
            if self._enqueued:
                self._enqueued.set()
                self._dequeued.clear()
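
The write above goes through a flattened [num_envs * max_length, ...] view so that a single advanced-indexing assignment can hit a different slot for every environment. A standalone sketch of that addressing with plain torch and hypothetical sizes:

import torch

num_envs, max_length = 4, 3
buffer = torch.zeros(num_envs, max_length)
flat = buffer.view(num_envs * max_length)      # shares storage with `buffer`

env_ids = torch.tensor([1, 3])                 # one new item per selected environment
current_pos = torch.tensor([5, 2])             # positions before wrapping
batch = torch.tensor([7., 9.])

indices = env_ids * max_length + current_pos % max_length
flat[indices] = batch                          # writes land in `buffer` as well
print(buffer)                                  # row 1 gets 7. at column 2, row 3 gets 9. at column 2
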
Example #16
    def get_field(self, field_name, env_ids, positions):
        """Get stored data of field from the replay buffer by ``env_ids`` and ``positions``.

        Args:
            field_name (str | nest of str): path to the field, with '.'
                separating field names at different levels.
            env_ids (Tensor): 1-D int64 Tensor.
            positions (Tensor): 1-D int64 Tensor with same shape as ``env_ids``.
                These positions should be obtained from the BatchInfo returned
                by ``get_batch()``.
        Returns:
            Tensor: with the same shape as the broadcast shape of ``env_ids``
                and ``positions``.
        """
        current_pos = self._current_pos[env_ids]
        assert torch.all(positions < current_pos), "Invalid positions"
        assert torch.all(
            positions >= current_pos - self._max_length), "Invalid positions"
        field = alf.nest.map_structure(
            lambda name: alf.nest.get_field(self._buffer, name), field_name)
        indices = (env_ids, self.circular(positions))
        result = alf.nest.map_structure(lambda x: x[indices], field)
        return convert_device(result)
Example #17
    def _stack_frame(obs, i):
        prev_obs = replay_buffer.get_field(self._exp_fields[i], env_ids,
                                           prev_positions)
        prev_obs = convert_device(prev_obs)
        stacked_shape = alf.nest.get_field(
            self._transformed_observation_spec, self._fields[i]).shape
        # [batch_size, mini_batch_length + stack_size - 1, ...]
        stacked_obs = torch.cat((prev_obs, obs), dim=1)
        # [batch_size, mini_batch_length, stack_size, ...]
        stacked_obs = stacked_obs[obs_index]
        if self._stack_axis != 0 and obs.ndim > 3:
            stack_axis = self._stack_axis
            if stack_axis < 0:
                stack_axis += stacked_obs.ndim
            else:
                stack_axis += 3
            stacked_obs = stacked_obs.unsqueeze(stack_axis)
            stacked_obs = stacked_obs.transpose(2, stack_axis)
            stacked_obs = stacked_obs.squeeze(2)
        stacked_obs = stacked_obs.reshape(batch_size, mini_batch_length,
                                          *stacked_shape)
        return stacked_obs
Example #18
    def add_batch(self, batch):
        r"""Add a batch of items to the buffer.

        Add batch_size items along the length of the underlying RingBuffer,
        whereas RingBuffer.enqueue only adds data of length 1.
        Truncates the data if ``batch_size > capacity``.

        Args:
            batch (Tensor): of shape ``[batch_size] + tensor_spec.shape``
        """
        batch_size = alf.nest.get_nest_batch_size(batch)
        with alf.device(self._device):
            batch = convert_device(batch)
            n = torch.clamp(self._capacity, max=batch_size)
            current_pos = self.current_pos
            current_size = self.current_size
            indices = self.circular(torch.arange(current_pos, current_pos + n))
            alf.nest.map_structure(
                lambda buf, bat: buf.__setitem__(indices, bat[-n:].detach()),
                self._derived_buffer, batch)

            current_pos.copy_(current_pos + n)
            current_size.copy_(torch.min(current_size + n, self._capacity))
Example #19
    def preprocess_experience(self, experience: Experience):
        """Fill experience.rollout_info with MuzeroInfo

        Note that the shape of experience is [B, T, ...]
        """
        assert experience.batch_info != ()
        batch_info: BatchInfo = experience.batch_info
        replay_buffer: ReplayBuffer = experience.replay_buffer
        info_path: str = experience.rollout_info_field
        mini_batch_length = experience.step_type.shape[1]
        assert mini_batch_length == 1, (
            "Only support TrainerConfig.mini_batch_length=1, got %s" %
            mini_batch_length)

        value_field = info_path + '.value'
        candidate_actions_field = info_path + '.candidate_actions'
        candidate_action_policy_field = (info_path +
                                         '.candidate_action_policy')

        with alf.device(replay_buffer.device):
            positions = convert_device(batch_info.positions)  # [B]
            env_ids = convert_device(batch_info.env_ids)  # [B]

            if self._reanalyze_ratio > 0:
                # Here we assume state and info have similar name scheme.
                mcts_state_field = 'state' + info_path[len('rollout_info'):]
                r = torch.rand(
                    experience.step_type.shape[0]) < self._reanalyze_ratio
                r_candidate_actions, r_candidate_action_policy, r_values = self._reanalyze(
                    replay_buffer, env_ids[r], positions[r], mcts_state_field)

            # [B]
            steps_to_episode_end = replay_buffer.steps_to_episode_end(
                positions, env_ids)
            # [B]
            episode_end_positions = positions + steps_to_episode_end

            # [B, unroll_steps+1]
            positions = positions.unsqueeze(-1) + torch.arange(
                self._num_unroll_steps + 1)
            # [B, 1]
            env_ids = batch_info.env_ids.unsqueeze(-1)
            # [B, 1]
            episode_end_positions = episode_end_positions.unsqueeze(-1)

            beyond_episode_end = positions > episode_end_positions
            positions = torch.min(positions, episode_end_positions)

            if self._td_steps >= 0:
                values = self._calc_bootstrap_return(replay_buffer, env_ids,
                                                     positions, value_field)
            else:
                values = self._calc_monte_carlo_return(replay_buffer, env_ids,
                                                       positions, value_field)

            candidate_actions = replay_buffer.get_field(
                candidate_actions_field, env_ids, positions)
            candidate_action_policy = replay_buffer.get_field(
                candidate_action_policy_field, env_ids, positions)

            if self._reanalyze_ratio > 0:
                if not _is_empty(candidate_actions):
                    candidate_actions[r] = r_candidate_actions
                candidate_action_policy[r] = r_candidate_action_policy
                values[r] = r_values

            game_overs = ()
            if self._train_game_over_function or self._train_reward_function:
                game_overs = positions == episode_end_positions
                discount = replay_buffer.get_field('discount', env_ids,
                                                   positions)
                # In the case of discount != 0, game_over may not be correct,
                # since the episode may have been truncated by TimeLimit or the
                # last episode in the replay buffer may be incomplete. There is
                # no way to know the future game overs for sure.
                game_overs = game_overs & (discount == 0.)

            rewards = ()
            if self._train_reward_function:
                rewards = self._get_reward(replay_buffer, env_ids, positions)
                rewards[beyond_episode_end] = 0.
                values[game_overs] = 0.

            if not self._train_game_over_function:
                game_overs = ()

            action = replay_buffer.get_field('action', env_ids,
                                             positions[:, :-1])

            rollout_info = MuzeroInfo(
                action=action,
                value=(),
                target=ModelTarget(reward=rewards,
                                   action=candidate_actions,
                                   action_policy=candidate_action_policy,
                                   value=values,
                                   game_over=game_overs))

        # make the shape to [B, T, ...], where T=1
        rollout_info = alf.nest.map_structure(lambda x: x.unsqueeze(1),
                                              rollout_info)
        rollout_info = convert_device(rollout_info)
        rollout_info = rollout_info._replace(
            value=experience.rollout_info.value)

        if self._reward_normalizer:
            experience = experience._replace(
                reward=rollout_info.target.reward[:, :, 0])
        return experience._replace(rollout_info=rollout_info)
Example #20
    def _reanalyze1(self, replay_buffer: ReplayBuffer, env_ids, positions,
                    mcts_state_field):
        """Reanalyze one batch.

        This means:
        1. Re-plan the policy using MCTS for the first n1 = 1 + num_unroll_steps
           steps to get fresh policy and value targets.
        2. Calculate the values for the following n2 = reanalyze_td_steps steps
           so that we have values for a total of 1 + num_unroll_steps +
           reanalyze_td_steps steps.
        3. Use these values and the rewards from the replay buffer to calculate
           the n2-step bootstrapped value target for the first n1 steps.

        In order to do 1 and 2, we need to get the observations for n1 + n2 steps
        and process them using the data_transformer.
        """
        batch_size = env_ids.shape[0]
        n1 = self._num_unroll_steps + 1
        n2 = self._reanalyze_td_steps
        env_ids, positions = self._next_n_positions(
            replay_buffer, env_ids, positions, self._num_unroll_steps + n2)
        # [B, n1]
        positions1 = positions[:, :n1]
        game_overs = replay_buffer.get_field('discount', env_ids,
                                             positions1) == 0.

        steps_to_episode_end = replay_buffer.steps_to_episode_end(
            positions1, env_ids)
        bootstrap_n = steps_to_episode_end.clamp(max=n2)

        exp1, exp2 = self._prepare_reanalyze_data(replay_buffer, env_ids,
                                                  positions, n1, n2)

        bootstrap_position = positions1 + bootstrap_n
        discount = replay_buffer.get_field('discount', env_ids,
                                           bootstrap_position)
        sum_reward = self._sum_discounted_reward(replay_buffer, env_ids,
                                                 positions1,
                                                 bootstrap_position, n2)

        if not self._train_reward_function:
            rewards = self._get_reward(replay_buffer, env_ids,
                                       bootstrap_position)

        with alf.device(self._device):
            bootstrap_n = convert_device(bootstrap_n)
            discount = convert_device(discount)
            sum_reward = convert_device(sum_reward)
            game_overs = convert_device(game_overs)

            # 1. Reanalyze the first n1 steps to get both the updated value and policy
            self._mcts.set_model(self._target_model)
            mcts_step = self._mcts.predict_step(
                exp1, alf.nest.get_field(exp1, mcts_state_field))
            self._mcts.set_model(self._model)
            candidate_actions = ()
            if not _is_empty(mcts_step.info.candidate_actions):
                candidate_actions = mcts_step.info.candidate_actions
                candidate_actions = candidate_actions.reshape(
                    batch_size, n1, *candidate_actions.shape[1:])
            candidate_action_policy = mcts_step.info.candidate_action_policy
            candidate_action_policy = candidate_action_policy.reshape(
                batch_size, n1, *candidate_action_policy.shape[1:])
            values1 = mcts_step.info.value.reshape(batch_size, n1)

            # 2. Calculate the values of the next n2 steps so that the n2-step
            # return can be computed.
            model_output = self._target_model.initial_inference(
                exp2.observation)
            values2 = model_output.value.reshape(batch_size, n2)

            # 3. Calculate n2-step return
            values = torch.cat([values1, values2], dim=1)
            # [B, n1]
            bootstrap_pos = torch.arange(n1).unsqueeze(0) + bootstrap_n
            values = values[torch.arange(batch_size).unsqueeze(-1),
                            bootstrap_pos]
            values = values * discount * (self._discount**bootstrap_n.to(
                torch.float32))
            values = values + sum_reward
            if not self._train_reward_function:
                # For this condition, we need to set the value at and after the
                # last step to be the last reward.
                values = torch.where(game_overs, convert_device(rewards),
                                     values)
            return candidate_actions, candidate_action_policy, values
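
Step 3 above is an n-step bootstrapped return: the discounted rewards of the next n steps plus the discounted value n steps ahead (the full code additionally zeroes the bootstrap value via the stored discount when the episode has ended). A toy single-sequence version with made-up numbers:

import torch

gamma, n = 0.99, 3
# hypothetical data for one sequence: rewards[i] follows state i, values[i] = V(state i)
rewards = torch.tensor([1., 0., 2., 1., 0.5])
values = torch.tensor([4., 3.5, 3., 2., 1.])

t = 1                                                         # target for state 1
sum_reward = sum(gamma**k * rewards[t + k] for k in range(n)) # discounted rewards of next n steps
target = sum_reward + gamma**n * values[t + n]                # plus discounted bootstrap value
print(target)                                                 # tensor(3.9304)
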
Example #21
    def transform_experience(self, experience: Experience):
        if self._stack_size == 1:
            return experience

        assert experience.batch_info != ()
        batch_info: BatchInfo = experience.batch_info
        replay_buffer: ReplayBuffer = experience.replay_buffer

        with alf.device(replay_buffer.device):
            # [B]
            env_ids = convert_device(batch_info.env_ids)
            # [B]
            positions = convert_device(batch_info.positions)

            prev_positions = torch.arange(self._stack_size -
                                          1) - self._stack_size + 1

            # [B, stack_size - 1]
            prev_positions = positions.unsqueeze(
                -1) + prev_positions.unsqueeze(0)
            episode_begin_positions = replay_buffer.get_episode_begin_position(
                positions, env_ids)
            # [B, 1]
            episode_begin_positions = episode_begin_positions.unsqueeze(-1)
            # [B, stack_size - 1]
            prev_positions = torch.max(prev_positions, episode_begin_positions)
            # [B, 1]
            env_ids = env_ids.unsqueeze(-1)
            assert torch.all(
                prev_positions[:, 0] >= replay_buffer.get_earliest_position(
                    env_ids)
            ), ("Some previous posisions are no longer in the replay buffer")

        batch_size, mini_batch_length = experience.step_type.shape

        # [[0, 1, ..., stack_size-1],
        #  [1, 2, ..., stack_size],
        #  ...
        #  [mini_batch_length - 1, ...]]
        #
        # [mini_batch_length, stack_size]
        obs_index = (torch.arange(self._stack_size).unsqueeze(0) +
                     torch.arange(mini_batch_length).unsqueeze(1))
        B = torch.arange(batch_size)
        obs_index = (B.unsqueeze(-1).unsqueeze(-1), obs_index.unsqueeze(0))

        def _stack_frame(obs, i):
            prev_obs = replay_buffer.get_field(self._exp_fields[i], env_ids,
                                               prev_positions)
            prev_obs = convert_device(prev_obs)
            stacked_shape = alf.nest.get_field(
                self._transformed_observation_spec, self._fields[i]).shape
            # [batch_size, mini_batch_length + stack_size - 1, ...]
            stacked_obs = torch.cat((prev_obs, obs), dim=1)
            # [batch_size, mini_batch_length, stack_size, ...]
            stacked_obs = stacked_obs[obs_index]
            if self._stack_axis != 0 and obs.ndim > 3:
                stack_axis = self._stack_axis
                if stack_axis < 0:
                    stack_axis += stacked_obs.ndim
                else:
                    stack_axis += 3
                stacked_obs = stacked_obs.unsqueeze(stack_axis)
                stacked_obs = stacked_obs.transpose(2, stack_axis)
                stacked_obs = stacked_obs.squeeze(2)
            stacked_obs = stacked_obs.reshape(batch_size, mini_batch_length,
                                              *stacked_shape)
            return stacked_obs

        observation = experience.observation
        for i, field in enumerate(self._fields):
            observation = alf.nest.transform_nest(observation, field,
                                                  partial(_stack_frame, i=i))
        return experience._replace(observation=observation)
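
The obs_index built above is a sliding-window gather: adding a [1, stack_size] offset row to a [mini_batch_length, 1] column gives, for each time step, the indices of the frames to stack with it. A minimal standalone sketch with plain torch and made-up sizes:

import torch

batch_size, mini_batch_length, stack_size = 2, 4, 3
# frames already prefixed with stack_size - 1 previous frames: [B, T + stack_size - 1]
frames = torch.arange(batch_size * (mini_batch_length + stack_size - 1)).reshape(batch_size, -1)

# [T, stack_size]: row t selects frames t, t+1, ..., t+stack_size-1
window = torch.arange(stack_size).unsqueeze(0) + torch.arange(mini_batch_length).unsqueeze(1)
B = torch.arange(batch_size)
obs_index = (B.unsqueeze(-1).unsqueeze(-1), window.unsqueeze(0))

stacked = frames[obs_index]            # [B, T, stack_size]
print(stacked.shape, stacked[0, 0])    # torch.Size([2, 4, 3]) tensor([0, 1, 2])
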
Example #22
    def add_batch(self, batch, env_ids=None, blocking=False):
        """Add a batch of entries to buffer updating indices as needed.

        We maintain an index of episode beginning positions for each element
        in the buffer. The beginning step stores the position of the episode end.

        Args:
            batch (Tensor): of shape ``[batch_size] + tensor_spec.shape``
            env_ids (Tensor): If ``None``, ``batch_size`` must be
                ``num_environments``. If not ``None``, its shape should be
                ``[batch_size]``. We assume there are no duplicate ids in
                ``env_id``. ``batch[i]`` is generated by environment
                ``env_ids[i]``.
            blocking (bool): If ``True``, blocks if there is no free slot to add
                data.  If ``False``, enqueue can overwrite oldest data.
        """
        with alf.device(self._device):
            env_ids = self.check_convert_env_ids(env_ids)
            if self._keep_episodic_info:
                assert not blocking, (
                    "HER replay buffer doesn't wait for dequeue to free up " +
                    "space, but instead just overwrites.")
                batch = convert_device(batch)
                # 1. save episode beginning data that will be overwritten
                overwriting_pos = self._current_pos[env_ids]
                buffer_step_types = alf.nest.get_field(self._buffer,
                                                       self._step_type_field)
                first, = torch.where(
                    (buffer_step_types[(env_ids,
                                        self.circular(overwriting_pos))]
                     == ds.StepType.FIRST) *
                    (self._current_size[env_ids] == self._max_length))
                first_env_ids = env_ids[first]
                first_step_idx = self.circular(overwriting_pos[first])
                self._headless_indexed_pos[first_env_ids] = self._indexed_pos[(
                    first_env_ids, first_step_idx)]

            # 2. enqueue batch
            self.enqueue(batch, env_ids, blocking=blocking)
            if self._prioritized_sampling:
                self._initialize_priority(env_ids)
                if self._num_earliest_frames_ignored > 0:
                    # Make sure the prioritized sampling ignores the earliest
                    # frames by setting their priorities to 0.
                    current_pos = self._current_pos[env_ids]
                    pos = current_pos - self._current_size[env_ids]
                    pos = pos + self._num_earliest_frames_ignored - 1
                    pos = torch.min(pos, current_pos - 1)
                    self.update_priority(
                        env_ids, pos, torch.zeros_like(pos,
                                                       dtype=torch.float32))

            if self._keep_episodic_info:
                # 3. Update associated episode end indices
                # 3.1. find ending steps in batch (incl. MID and LAST steps)
                step_types = alf.nest.get_field(batch, self._step_type_field)
                non_first, = torch.where(step_types != ds.StepType.FIRST)
                # 3.2. update episode ending positions
                self._store_episode_end_pos(non_first, overwriting_pos,
                                            env_ids)
                # 3.3. initialize episode beginning positions to itself
                epi_first, = torch.where(step_types == ds.StepType.FIRST)
                self._indexed_pos[(env_ids[epi_first],
                                   self.circular(overwriting_pos[epi_first])
                                   )] = overwriting_pos[epi_first]
Example #23
    def get_all(self):
        """Return the whole buffer content."""
        return convert_device(
            alf.nest.map_structure(lambda buf: buf, self._derived_buffer))
Example #24
    def get_batch(self, batch_size, batch_length):
        """Randomly get ``batch_size`` trajectories from the buffer.

        It could hindsight relabel the experience via postprocess_exp_fn.

        Note: The environments where the samples are from are ordered in the
            returned batch.

        Args:
            batch_size (int): get so many trajectories
            batch_length (int): the length of each trajectory
        Returns:
            tuple:
                - nested Tensors: The samples. Its shapes are [batch_size, batch_length, ...]
                - BatchInfo: Information about the batch. Its shapes are [batch_size].
                    - env_ids: environment id for each sequence
                    - positions: starting position in the replay buffer for each sequence.
                    - importance_weights: priority divided by the average of all
                        non-zero priorities in the buffer.
        """
        with alf.device(self._device):
            recent_batch_size = 0
            if self._recent_data_ratio > 0:
                d = batch_length - 1 + self._num_earliest_frames_ignored
                avg_size = self.total_size / float(self._num_envs) - d
                if (avg_size * self._recent_data_ratio >
                        self._recent_data_steps):
                    # If this condition is False, regular sampling without considering
                    # recent data will get enough samples from recent data. So
                    # we don't need to have a separate step just for sampling from
                    # the recent data.
                    recent_batch_size = math.ceil(batch_size *
                                                  self._recent_data_ratio)

            normal_batch_size = batch_size - recent_batch_size
            if self._prioritized_sampling:
                info = self._prioritized_sample(normal_batch_size,
                                                batch_length)
            else:
                info = self._uniform_sample(normal_batch_size, batch_length)

            if recent_batch_size > 0:
                # Note that _uniform_sample() may get samples that duplicate
                # those from _recent_sample()
                recent_info = self._recent_sample(recent_batch_size,
                                                  batch_length)
                info = alf.nest.map_structure(lambda *x: torch.cat(x),
                                              recent_info, info)

            start_pos = info.positions
            env_ids = info.env_ids

            idx = start_pos.reshape(-1, 1)  # [B, 1]
            idx = self.circular(
                idx + torch.arange(batch_length).unsqueeze(0))  # [B, T]
            out_env_ids = env_ids.reshape(-1,
                                          1).expand(batch_size,
                                                    batch_length)  # [B, T]
            result = alf.nest.map_structure(lambda b: b[(out_env_ids, idx)],
                                            self._buffer)

            if alf.summary.should_record_summaries():
                alf.summary.scalar(
                    "replayer/" + self._name + ".original_reward_mean",
                    torch.mean(result.reward[:-1]))

            if self._postprocess_exp_fn:
                result, info = self._postprocess_exp_fn(self, result, info)

        if alf.get_default_device() == self._device:
            return result, info
        else:
            return convert_device(result), convert_device(info)
Example #25
    def total_size(self):
        """Total size from all environments."""
        return convert_device(self._current_size.sum())