Example no. 1
    def __init__(self, capacity=1000, alpha=1.0, beta=1.0):
        """
        Args:
            capacity (int): Max capacity.
            alpha (float): Initial weight.
            beta (float): Prioritisation factor.
        """
        super(ApexMemory, self).__init__()

        self.memory_values = []
        self.index = 0
        self.capacity = capacity
        self.size = 0
        self.max_priority = 1.0
        self.alpha = alpha
        self.beta = beta

        self.default_new_weight = np.power(self.max_priority, self.alpha)
        self.priority_capacity = 1
        while self.priority_capacity < self.capacity:
            self.priority_capacity *= 2

        # Create segment trees, initialize with neutral elements.
        sum_values = [0.0 for _ in range_(2 * self.priority_capacity)]
        sum_segment_tree = MemSegmentTree(sum_values, self.priority_capacity,
                                          operator.add)
        min_values = [float('inf') for _ in range_(2 * self.priority_capacity)]
        min_segment_tree = MemSegmentTree(min_values, self.priority_capacity,
                                          min)
        self.merged_segment_tree = MinSumSegmentTree(
            sum_tree=sum_segment_tree,
            min_tree=min_segment_tree,
            capacity=self.priority_capacity)
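
The loop above rounds priority_capacity up to the next power of two so that both segment trees are complete binary trees over priority_capacity leaves, each initialised with the neutral element of its operator (0.0 for the sum tree, float('inf') for the min tree). A minimal standalone sketch of just that rounding step (plain Python, independent of MemSegmentTree):

def next_power_of_two(capacity):
    # Same rounding as the while-loop used for priority_capacity above.
    priority_capacity = 1
    while priority_capacity < capacity:
        priority_capacity *= 2
    return priority_capacity

assert next_power_of_two(1000) == 1024
assert next_power_of_two(1024) == 1024
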
Example no. 2
    def create_variables(self, input_spaces, action_space=None):
        super(MemPrioritizedReplay, self).create_variables(input_spaces, action_space)
        self.priority_capacity = 1
        while self.priority_capacity < self.capacity:
            self.priority_capacity *= 2

        # Create segment trees, initialize with neutral elements.
        sum_values = [0.0 for _ in range_(2 * self.priority_capacity)]
        sum_segment_tree = MemSegmentTree(sum_values, self.priority_capacity, operator.add)
        min_values = [float('inf') for _ in range_(2 * self.priority_capacity)]
        min_segment_tree = MemSegmentTree(min_values, self.priority_capacity, min)

        self.merged_segment_tree = MinSumSegmentTree(
            sum_tree=sum_segment_tree,
            min_tree=min_segment_tree,
            capacity=self.priority_capacity
        )
Example no. 3
    def __init__(self,
                 capacity=1000,
                 alpha=1.0,
                 beta=1.0,
                 n_step_adjustment=1):
        """
        Args:
            capacity (int): Max capacity.
            alpha (float): Initial weight.
            beta (float): Prioritisation factor.
            n_step_adjustment (int): Number of steps to advance when reading next
                states (1 corresponds to the direct next state).
        """
        super(ApexMemory, self).__init__()

        self.memory_values = []
        self.index = 0
        self.capacity = capacity
        self.size = 0
        self.max_priority = 1.0
        self.alpha = alpha
        self.beta = beta
        # TODO this is not used here any more
        # assert n_step_adjustment > 0, "ERROR: n-step adjustment must be at least 1, where 1 corresponds " \
        #     "to the direct next state."
        self.n_step_adjustment = n_step_adjustment

        self.default_new_weight = np.power(self.max_priority, self.alpha)
        self.priority_capacity = 1
        while self.priority_capacity < self.capacity:
            self.priority_capacity *= 2

        # Create segment trees, initialize with neutral elements.
        sum_values = [0.0 for _ in range_(2 * self.priority_capacity)]
        sum_segment_tree = MemSegmentTree(sum_values, self.priority_capacity,
                                          operator.add)
        min_values = [float('inf') for _ in range_(2 * self.priority_capacity)]
        min_segment_tree = MemSegmentTree(min_values, self.priority_capacity,
                                          min)
        self.merged_segment_tree = MinSumSegmentTree(
            sum_tree=sum_segment_tree,
            min_tree=min_segment_tree,
            capacity=self.priority_capacity)
Example no. 4
    def create_variables(self, input_spaces, action_space=None):
        # Store our record-space for convenience.
        self.record_space = input_spaces["records"]
        self.record_space_flat = Dict(self.record_space.flatten(
            custom_scope_separator="/", scope_separator_at_start=False),
                                      add_batch_rank=True)
        self.priority_capacity = 1

        while self.priority_capacity < self.capacity:
            self.priority_capacity *= 2

        # Create segment trees, initialize with neutral elements.
        sum_values = [0.0 for _ in range_(2 * self.priority_capacity)]
        sum_segment_tree = MemSegmentTree(sum_values, self.priority_capacity,
                                          operator.add)
        min_values = [float('inf') for _ in range_(2 * self.priority_capacity)]
        min_segment_tree = MemSegmentTree(min_values, self.priority_capacity,
                                          min)

        self.merged_segment_tree = MinSumSegmentTree(
            sum_tree=sum_segment_tree,
            min_tree=min_segment_tree,
            capacity=self.priority_capacity)
Example no. 5
class LegacyApexMemory(Specifiable):
    """
    Apex prioritized replay implementing compression.
    """
    def __init__(self,
                 capacity=1000,
                 alpha=1.0,
                 beta=1.0,
                 n_step_adjustment=1):
        """
        Args:
            capacity (int): Max capacity.
            alpha (float): Initial weight.
            beta (float): Prioritisation factor.
            n_step_adjustment (int): Number of steps to advance when reading next
                states (1 corresponds to the direct next state).
        """
        super(LegacyApexMemory, self).__init__()

        self.memory_values = []
        self.index = 0
        self.capacity = capacity
        self.size = 0
        self.max_priority = 1.0
        self.alpha = alpha
        self.beta = beta
        assert n_step_adjustment > 0, "ERROR: n-step adjustment must be at least 1, where 1 corresponds " \
            "to the direct next state."
        self.n_step_adjustment = n_step_adjustment

        self.default_new_weight = np.power(self.max_priority, self.alpha)
        self.priority_capacity = 1
        while self.priority_capacity < self.capacity:
            self.priority_capacity *= 2

        # Create segment trees, initialize with neutral elements.
        sum_values = [0.0 for _ in range_(2 * self.priority_capacity)]
        sum_segment_tree = MemSegmentTree(sum_values, self.priority_capacity,
                                          operator.add)
        min_values = [float('inf') for _ in range_(2 * self.priority_capacity)]
        min_segment_tree = MemSegmentTree(min_values, self.priority_capacity,
                                          min)
        self.merged_segment_tree = MinSumSegmentTree(
            sum_tree=sum_segment_tree,
            min_tree=min_segment_tree,
            capacity=self.priority_capacity)

    def insert_records(self, record):
        # TODO: This has the record interface, but actually expects a specific structure anyway, so
        # may as well change API?
        if self.index >= self.size:
            self.memory_values.append(record)
        else:
            self.memory_values[self.index] = record

        # Weights. # TODO this is problematic due to index not existing.
        if record[4] is not None:
            self.merged_segment_tree.insert(self.index, record[4])
        else:
            self.merged_segment_tree.insert(self.index,
                                            self.default_new_weight)

        # Update indices.
        self.index = (self.index + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def read_records(self, indices):
        """
        Obtains record values for the provided indices.

        Args:
            indices (ndarray): Indices to read. Assumed not to be contiguous.

        Returns:
             dict: Record value dict.
        """
        states = list()
        actions = list()
        rewards = list()
        terminals = list()
        next_states = list()
        for index in indices:
            # TODO: remove the np.array casts if they don't help.
            state, action, reward, terminal, weight = self.memory_values[index]
            decompressed_state = np.array(ray_decompress(state), copy=False)
            states.append(decompressed_state)
            actions.append(np.array(action, copy=False))
            rewards.append(reward)
            terminals.append(terminal)

            decompressed_next_state = decompressed_state
            # If terminal -> just use same state, already decompressed
            if terminal:
                next_states.append(decompressed_next_state)
            else:
                # Otherwise advance until correct next state or terminal.
                next_state = decompressed_next_state
                for i in range_(self.n_step_adjustment):
                    next_index = (index + i + 1) % self.size
                    next_state, _, _, terminal, _ = self.memory_values[
                        next_index]
                    if terminal:
                        break
                next_states.append(
                    np.array(ray_decompress(next_state), copy=False))

        return dict(states=np.array(states),
                    actions=np.array(actions),
                    rewards=np.array(rewards),
                    terminals=np.array(terminals),
                    next_states=np.array(next_states))

    def get_records(self, num_records):
        indices = []
        # Ensure we always have n-next states.
        # TODO potentially block this if size - 1 - nstep < 1?
        prob_sum = self.merged_segment_tree.sum_segment_tree.get_sum(
            0, self.size - 1 - self.n_step_adjustment)
        samples = np.random.random(size=(num_records, )) * prob_sum
        for sample in samples:
            indices.append(
                self.merged_segment_tree.sum_segment_tree.index_of_prefixsum(
                    prefix_sum=sample))

        sum_prob = self.merged_segment_tree.sum_segment_tree.get_sum()
        min_prob = self.merged_segment_tree.min_segment_tree.get_min_value(
        ) / sum_prob
        max_weight = (min_prob * self.size)**(-self.beta)
        weights = []
        for index in indices:
            sample_prob = self.merged_segment_tree.sum_segment_tree.get(
                index) / sum_prob
            weight = (sample_prob * self.size)**(-self.beta)
            weights.append(weight / max_weight)

        indices = np.array(indices, copy=False)
        return self.read_records(indices=indices), indices, np.array(
            weights, copy=False)

    def update_records(self, indices, update):
        for index, loss in zip(indices, update):
            priority = np.power(loss, self.alpha)
            self.merged_segment_tree.insert(index, priority)
            self.max_priority = max(self.max_priority, priority)
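
get_records above samples indices proportionally to their priorities: each uniform draw in [0, prob_sum) is mapped to the first slot whose cumulative priority exceeds it via sum_segment_tree.index_of_prefixsum. A rough numpy equivalent of that mapping, using a plain priority array and made-up values instead of the segment tree:

import numpy as np

priorities = np.array([0.5, 2.0, 1.0, 0.5])  # hypothetical per-slot priorities (p_i ** alpha)
prefix_sums = np.cumsum(priorities)          # [0.5, 2.5, 3.5, 4.0]
samples = np.random.random(size=(3,)) * prefix_sums[-1]
# First slot whose cumulative priority exceeds the sampled value; the segment
# tree's index_of_prefixsum computes the same thing in O(log n) per query.
indices = np.searchsorted(prefix_sums, samples, side="right")
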
Example no. 6
class ApexMemory(Specifiable):
    """
    Apex prioritized replay implementing compression.
    """
    def __init__(self,
                 capacity=1000,
                 alpha=1.0,
                 beta=1.0,
                 n_step_adjustment=1):
        """
        Args:
            capacity (int): Max capacity.
            alpha (float): Initial weight.
            beta (float): Prioritisation factor.
            n_step_adjustment (int): Number of steps to advance when reading next
                states (1 corresponds to the direct next state).
        """
        super(ApexMemory, self).__init__()

        self.memory_values = []
        self.index = 0
        self.capacity = capacity
        self.size = 0
        self.max_priority = 1.0
        self.alpha = alpha
        self.beta = beta
        # TODO this is not used here any more
        # assert n_step_adjustment > 0, "ERROR: n-step adjustment must be at least 1, where 1 corresponds " \
        #     "to the direct next state."
        self.n_step_adjustment = n_step_adjustment

        self.default_new_weight = np.power(self.max_priority, self.alpha)
        self.priority_capacity = 1
        while self.priority_capacity < self.capacity:
            self.priority_capacity *= 2

        # Create segment trees, initialize with neutral elements.
        sum_values = [0.0 for _ in range_(2 * self.priority_capacity)]
        sum_segment_tree = MemSegmentTree(sum_values, self.priority_capacity,
                                          operator.add)
        min_values = [float('inf') for _ in range_(2 * self.priority_capacity)]
        min_segment_tree = MemSegmentTree(min_values, self.priority_capacity,
                                          min)
        self.merged_segment_tree = MinSumSegmentTree(
            sum_tree=sum_segment_tree,
            min_tree=min_segment_tree,
            capacity=self.priority_capacity)

    def insert_records(self, record):
        # TODO: This has the record interface, but actually expects a specific structure anyway, so
        # may as well change API?
        if self.index >= self.size:
            self.memory_values.append(record)
        else:
            self.memory_values[self.index] = record

        # Weights. # TODO this is problematic due to index not existing.
        if record[5] is not None:
            self.merged_segment_tree.insert(self.index, record[5]**self.alpha)
        else:
            self.merged_segment_tree.insert(self.index,
                                            self.max_priority**self.alpha)

        # Update indices.
        self.index = (self.index + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def read_records(self, indices):
        """
        Obtains record values for the provided indices.

        Args:
            indices (ndarray): Indices to read. Assumed not to be contiguous.

        Returns:
             dict: Record value dict.
        """
        states = []
        actions = []
        rewards = []
        terminals = []
        next_states = []
        for index in indices:
            state, action, reward, terminal, next_state, weight = self.memory_values[
                index]
            states.append(ray_decompress(state))
            actions.append(action)
            rewards.append(reward)
            terminals.append(terminal)
            next_states.append(ray_decompress(next_state))

        return dict(states=np.asarray(states),
                    actions=np.asarray(actions),
                    rewards=np.asarray(rewards),
                    terminals=np.asarray(terminals),
                    next_states=np.asarray(next_states))

    def get_records(self, num_records):
        indices = []
        prob_sum = self.merged_segment_tree.sum_segment_tree.get_sum(
            0, self.size)
        samples = np.random.random(size=(num_records, )) * prob_sum
        for sample in samples:
            indices.append(
                self.merged_segment_tree.sum_segment_tree.index_of_prefixsum(
                    prefix_sum=sample))

        sum_prob = self.merged_segment_tree.sum_segment_tree.get_sum()
        min_prob = self.merged_segment_tree.min_segment_tree.get_min_value(
        ) / sum_prob + SMALL_NUMBER
        max_weight = (min_prob * self.size)**(-self.beta)
        weights = []
        for index in indices:
            sample_prob = self.merged_segment_tree.sum_segment_tree.get(
                index) / sum_prob
            weight = (sample_prob * self.size)**(-self.beta)
            weights.append(weight / max_weight)

        return self.read_records(
            indices=indices), np.asarray(indices), np.asarray(weights)

    def update_records(self, indices, update):
        for index, loss in zip(indices, update):
            self.merged_segment_tree.insert(index, loss**self.alpha)
            self.max_priority = max(self.max_priority, loss)
Example no. 7
class MemPrioritizedReplay(Memory):
    """
    Implements an in-memory prioritized replay.

    API:
        update_records(indices, update) -> Updates the given indices with the given priority scores.
    """
    def __init__(self, capacity=1000, next_states=True, alpha=1.0, beta=0.0):
        super(MemPrioritizedReplay, self).__init__()

        self.memory_values = []
        self.index = 0
        self.capacity = capacity

        self.size = 0
        self.max_priority = 1.0
        self.alpha = alpha
        self.beta = beta
        self.next_states = next_states

        self.default_new_weight = np.power(self.max_priority, self.alpha)

    def create_variables(self, input_spaces, action_space=None):
        super(MemPrioritizedReplay,
              self).create_variables(input_spaces, action_space)
        self.priority_capacity = 1
        while self.priority_capacity < self.capacity:
            self.priority_capacity *= 2

        # Create segment trees, initialize with neutral elements.
        sum_values = [0.0 for _ in range_(2 * self.priority_capacity)]
        sum_segment_tree = MemSegmentTree(sum_values, self.priority_capacity,
                                          operator.add)
        min_values = [float('inf') for _ in range_(2 * self.priority_capacity)]
        min_segment_tree = MemSegmentTree(min_values, self.priority_capacity,
                                          min)

        self.merged_segment_tree = MinSumSegmentTree(
            sum_tree=sum_segment_tree,
            min_tree=min_segment_tree,
            capacity=self.priority_capacity)

    @rlgraph_api(flatten_ops=True)
    def _graph_fn_insert_records(self, records):
        if records is None or get_rank(records[self.terminal_key]) == 0:
            return
        num_records = len(records[self.terminal_key])

        if num_records == 1:
            if self.index >= self.size:
                self.memory_values.append(records)
            else:
                self.memory_values[self.index] = records
            self.merged_segment_tree.insert(self.index,
                                            self.default_new_weight)
        else:
            insert_indices = np.arange(
                start=self.index,
                stop=self.index + num_records) % self.capacity
            i = 0
            for insert_index in insert_indices:
                self.merged_segment_tree.insert(insert_index,
                                                self.default_new_weight)
                record = {}
                for name, record_values in records.items():
                    record[name] = record_values[i]
                if insert_index >= self.size:
                    self.memory_values.append(record)
                else:
                    self.memory_values[insert_index] = record
                i += 1

        # Update indices
        self.index = (self.index + num_records) % self.capacity
        self.size = min(self.size + num_records, self.capacity)

    @rlgraph_api
    def _graph_fn_get_records(self, num_records=1):
        available_records = min(num_records, self.size)
        indices = []
        prob_sum = self.merged_segment_tree.sum_segment_tree.get_sum(
            0, self.size - 1)
        samples = np.random.random(size=(available_records, )) * prob_sum
        for sample in samples:
            indices.append(
                self.merged_segment_tree.sum_segment_tree.index_of_prefixsum(
                    prefix_sum=sample))

        sum_prob = self.merged_segment_tree.sum_segment_tree.get_sum(
        ) + SMALL_NUMBER
        min_prob = self.merged_segment_tree.min_segment_tree.get_min_value(
        ) / sum_prob
        max_weight = (min_prob * self.size)**(-self.beta)
        weights = []
        for index in indices:
            sample_prob = self.merged_segment_tree.sum_segment_tree.get(
                index) / sum_prob
            weight = (sample_prob * self.size)**(-self.beta)
            weights.append(weight / max_weight)

        if get_backend() == "pytorch":
            indices = torch.tensor(indices)
            weights = torch.tensor(weights)
        else:
            indices = np.asarray(indices)
            weights = np.asarray(weights)

        records = DataOpDict()
        for name, variable in self.memory.items():
            records[name] = self.read_variable(
                variable,
                indices,
                dtype=util.convert_dtype(self.flat_record_space[name].dtype,
                                         to="pytorch"))
        records = define_by_run_unflatten(records)
        return records, indices, weights

    @rlgraph_api(must_be_complete=False)
    def _graph_fn_update_records(self, indices, update):
        for index, loss in zip(indices, update):
            priority = np.power(loss, self.alpha)
            self.merged_segment_tree.insert(index, priority)
            self.max_priority = max(self.max_priority, priority)

    def get_state(self):
        return {
            "size": self.size,
            "index": self.index,
            "max_priority": self.max_priority
        }
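
The weight computation in _graph_fn_get_records is the usual prioritized-replay importance-sampling correction: each sampled index receives (P(i) * N) ** (-beta), normalised by the largest possible weight (the one belonging to the minimum-probability slot) so that all weights fall in (0, 1]. A small numeric sketch of that normalisation with made-up priorities:

import numpy as np

beta = 0.5
priorities = np.array([0.5, 2.0, 1.0, 0.5])  # hypothetical p_i ** alpha values
size = len(priorities)

sum_prob = priorities.sum()
min_prob = priorities.min() / sum_prob
max_weight = (min_prob * size) ** (-beta)

sample_probs = priorities / sum_prob
weights = (sample_probs * size) ** (-beta) / max_weight
# Low-priority (rarely sampled) slots get weight 1.0; high-priority slots get smaller weights.
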
Example no. 8
class MemPrioritizedReplay(Memory):
    """
    Implements an in-memory prioritized replay.

    API:
        update_records(indices, update) -> Updates the given indices with the given priority scores.
    """
    def __init__(self, capacity=1000, next_states=True, alpha=1.0, beta=0.0):
        super(MemPrioritizedReplay, self).__init__()

        self.memory_values = []
        self.index = 0
        self.capacity = capacity

        self.size = 0
        self.max_priority = 1.0
        self.alpha = alpha

        self.beta = beta
        self.next_states = next_states

        self.default_new_weight = np.power(self.max_priority, self.alpha)

    def create_variables(self, input_spaces, action_space=None):
        # Store our record-space for convenience.
        self.record_space = input_spaces["records"]
        self.record_space_flat = Dict(self.record_space.flatten(
            custom_scope_separator="/", scope_separator_at_start=False),
                                      add_batch_rank=True)
        self.priority_capacity = 1

        while self.priority_capacity < self.capacity:
            self.priority_capacity *= 2

        # Create segment trees, initialize with neutral elements.
        sum_values = [0.0 for _ in range_(2 * self.priority_capacity)]
        sum_segment_tree = MemSegmentTree(sum_values, self.priority_capacity,
                                          operator.add)
        min_values = [float('inf') for _ in range_(2 * self.priority_capacity)]
        min_segment_tree = MemSegmentTree(min_values, self.priority_capacity,
                                          min)

        self.merged_segment_tree = MinSumSegmentTree(
            sum_tree=sum_segment_tree,
            min_tree=min_segment_tree,
            capacity=self.priority_capacity)

    def _read_records(self, indices):
        """
        Obtains record values for the provided indices.

        Args:
            indices (ndarray): Indices to read. Assumed not to be contiguous.

        Returns:
             dict: Record value dict.
        """
        records = {}
        for name in self.record_space_flat.keys():
            records[name] = []

        if self.size > 0:
            for index in indices:
                record = self.memory_values[index]
                for name in self.record_space_flat.keys():
                    records[name].append(record[name])

        else:
            # TODO figure out how to do default handling in pytorch builds.
            # Fill with default vals for build.
            for name in self.record_space_flat.keys():
                if get_backend() == "pytorch":
                    records[name] = torch.zeros(
                        self.record_space_flat[name].shape,
                        dtype=dtype_(self.record_space_flat[name].dtype,
                                     "pytorch"))
                else:
                    records[name] = np.zeros(
                        self.record_space_flat[name].shape)

        # Convert if necessary: list of tensors fails at space inference otherwise.
        if get_backend() == "pytorch":
            for name in self.record_space_flat.keys():
                records[name] = torch.squeeze(torch.stack(records[name]))

        return records

    @rlgraph_api(flatten_ops=True)
    def _graph_fn_insert_records(self, records):
        if records is None or get_rank(records['rewards']) == 0:
            return
        num_records = len(records['rewards'])

        if num_records == 1:
            if self.index >= self.size:
                self.memory_values.append(records)
            else:
                self.memory_values[self.index] = records
            self.merged_segment_tree.insert(self.index,
                                            self.default_new_weight)
        else:
            insert_indices = np.arange(
                start=self.index,
                stop=self.index + num_records) % self.capacity
            i = 0
            for insert_index in insert_indices:
                self.merged_segment_tree.insert(insert_index,
                                                self.default_new_weight)
                record = dict()
                for name, record_values in records.items():
                    record[name] = record_values[i]
                if insert_index >= self.size:
                    self.memory_values.append(record)
                else:
                    self.memory_values[insert_index] = record
                i += 1

        # Update indices
        self.index = (self.index + num_records) % self.capacity
        self.size = min(self.size + num_records, self.capacity)

    @rlgraph_api
    def _graph_fn_get_records(self, num_records=1):
        indices = []
        prob_sum = self.merged_segment_tree.sum_segment_tree.get_sum(
            0, self.size - 1)
        samples = np.random.random(size=(num_records, )) * prob_sum
        for sample in samples:
            indices.append(
                self.merged_segment_tree.sum_segment_tree.index_of_prefixsum(
                    prefix_sum=sample))

        sum_prob = self.merged_segment_tree.sum_segment_tree.get_sum(
        ) + SMALL_NUMBER
        min_prob = self.merged_segment_tree.min_segment_tree.get_min_value(
        ) / sum_prob
        max_weight = (min_prob * self.size)**(-self.beta)
        weights = []
        for index in indices:
            sample_prob = self.merged_segment_tree.sum_segment_tree.get(
                index) / sum_prob
            weight = (sample_prob * self.size)**(-self.beta)
            weights.append(weight / max_weight)

        if get_backend() == "pytorch":
            indices = torch.tensor(indices)
            weights = torch.tensor(weights)
        else:
            indices = np.asarray(indices)
            weights = np.asarray(weights)
        return self._read_records(indices=indices), indices, weights

    @rlgraph_api(must_be_complete=False)
    def _graph_fn_update_records(self, indices, update):
        if len(indices) > 0 and indices[0]:
            for index, loss in zip(indices, update):
                priority = np.power(loss, self.alpha)
                self.merged_segment_tree.insert(index, priority)
                self.max_priority = max(self.max_priority, priority)
Example no. 9
class ApexMemory(Specifiable):
    """
    Apex prioritized replay implementing compression.
    """
    def __init__(self,
                 state_space=None,
                 action_space=None,
                 capacity=1000,
                 alpha=1.0,
                 beta=1.0):
        """
        Args:
            state_space (dict): State spec.
            action_space (dict): Actions spec.
            capacity (int): Max capacity.
            alpha (float): Initial weight.
            beta (float): Prioritisation factor.
        """
        super(ApexMemory, self).__init__()

        self.state_space = state_space
        self.action_space = action_space
        self.container_actions = isinstance(action_space, dict)
        self.memory_values = []
        self.index = 0
        self.capacity = capacity
        self.size = 0
        self.max_priority = 1.0
        self.alpha = alpha
        self.beta = beta

        self.default_new_weight = np.power(self.max_priority, self.alpha)
        self.priority_capacity = 1
        while self.priority_capacity < self.capacity:
            self.priority_capacity *= 2

        # Create segment trees, initialize with neutral elements.
        sum_values = [0.0 for _ in range_(2 * self.priority_capacity)]
        sum_segment_tree = MemSegmentTree(sum_values, self.priority_capacity,
                                          operator.add)
        min_values = [float('inf') for _ in range_(2 * self.priority_capacity)]
        min_segment_tree = MemSegmentTree(min_values, self.priority_capacity,
                                          min)
        self.merged_segment_tree = MinSumSegmentTree(
            sum_tree=sum_segment_tree,
            min_tree=min_segment_tree,
            capacity=self.priority_capacity)

    def insert_records(self, record):
        # TODO: This has the record interface, but actually expects a specific structure anyway, so
        # may as well change API?
        if self.index >= self.size:
            self.memory_values.append(record)
        else:
            self.memory_values[self.index] = record

        # Weights. # TODO this is problematic due to index not existing.
        if record[5] is not None:
            self.merged_segment_tree.insert(self.index, record[5]**self.alpha)
        else:
            self.merged_segment_tree.insert(self.index,
                                            self.max_priority**self.alpha)

        # Update indices.
        self.index = (self.index + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def read_records(self, indices):
        """
        Obtains record values for the provided indices.

        Args:
            indices (ndarray): Indices to read. Assumed not to be contiguous.

        Returns:
             dict: Record value dict.
        """
        states = []
        if self.container_actions:
            actions = {k: [] for k in self.action_space.keys()}
        else:
            actions = []
        rewards = []
        terminals = []
        next_states = []
        for index in indices:
            state, action, reward, terminal, next_state, weight = self.memory_values[
                index]
            states.append(ray_decompress(state))

            if self.container_actions:
                for name in self.action_space.keys():
                    actions[name].append(action[name])
            else:
                actions.append(action)
            rewards.append(reward)
            terminals.append(terminal)
            next_states.append(ray_decompress(next_state))

        if self.container_actions:
            for name in self.action_space.keys():
                actions[name] = np.squeeze(np.array(actions[name]))
        else:
            actions = np.array(actions)
        return dict(states=np.asarray(states),
                    actions=actions,
                    rewards=np.asarray(rewards),
                    terminals=np.asarray(terminals),
                    next_states=np.asarray(next_states))

    def get_records(self, num_records):
        indices = []
        prob_sum = self.merged_segment_tree.sum_segment_tree.get_sum(
            0, self.size)
        samples = np.random.random(size=(num_records, )) * prob_sum
        for sample in samples:
            indices.append(
                self.merged_segment_tree.sum_segment_tree.index_of_prefixsum(
                    prefix_sum=sample))

        sum_prob = self.merged_segment_tree.sum_segment_tree.get_sum()
        min_prob = self.merged_segment_tree.min_segment_tree.get_min_value(
        ) / sum_prob + SMALL_NUMBER
        max_weight = (min_prob * self.size)**(-self.beta)
        weights = []
        for index in indices:
            sample_prob = self.merged_segment_tree.sum_segment_tree.get(
                index) / sum_prob
            weight = (sample_prob * self.size)**(-self.beta)
            weights.append(weight / max_weight)

        return self.read_records(
            indices=indices), np.asarray(indices), np.asarray(weights)

    def update_records(self, indices, update):
        for index, loss in zip(indices, update):
            self.merged_segment_tree.insert(index, loss**self.alpha)
            self.max_priority = max(self.max_priority, loss)
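
Note that the two Apex variants shown above index the record tuple differently on insert: LegacyApexMemory reads the priority weight from record[4] (a 5-tuple without an explicit next state), while ApexMemory reads it from record[5] (a 6-tuple that stores the next state). A sketch of the two layouts, inferred from the unpacking in the respective read_records methods and using dummy values:

import numpy as np

state = np.zeros(4)
next_state = np.ones(4)
action, reward, terminal, weight = 1, 0.0, False, None  # weight=None falls back to the default priority

# LegacyApexMemory: the next state is reconstructed from later indices on read.
legacy_record = (state, action, reward, terminal, weight)

# ApexMemory: the next state is stored alongside each transition.
apex_record = (state, action, reward, terminal, next_state, weight)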