Beispiel #1
0
    def _graph_fn_insert_records(self, records):
        """
        Writes a batch of records into the memory starting at the current write index,
        wrapping around at `self.capacity`, then advances the write index and the size.

        Args:
            records: Mapping from record key to batched values. Must contain
                `self.terminal_key` (the batch size is taken from it) and every key
                found in `self.memory`.

        Returns:
            tf backend: a no-op that is control-dependent on the index/size assignments.
            pytorch backend: None.
        """
        batch_len = get_batch_size(records[self.terminal_key])

        if get_backend() == "tf":
            # Slots to overwrite: `batch_len` consecutive positions, modulo the capacity.
            slots = tf.range(start=self.index, limit=self.index + batch_len) % self.capacity

            # One scatter update per stored record component.
            scatter_ops = [
                self.scatter_update_variable(
                    variable=self.memory[key],
                    indices=slots,
                    updates=records[key]
                )
                for key in self.memory
            ]

            # Move the write index and grow the size only after the records are written.
            with tf.control_dependencies(control_inputs=scatter_ops):
                next_index = (self.index + batch_len) % self.capacity
                advance_index = self.assign_variable(ref=self.index, value=next_index)
                next_size = tf.minimum(x=(self.read_variable(self.size) + batch_len), y=self.capacity)
                resize = self.assign_variable(self.size, value=next_size)

            # No value to hand back; just force the bookkeeping ops to run.
            with tf.control_dependencies(control_inputs=[advance_index, resize]):
                return tf.no_op()
        elif get_backend() == "pytorch":
            slots = torch.arange(self.index, self.index + batch_len) % self.capacity
            # Element-wise copy of every record component into its slot.
            for key in self.memory:
                for slot, value in zip(slots, records[key]):
                    self.memory[key][slot] = value
            self.index = (self.index + batch_len) % self.capacity
            self.size = min(self.size + batch_len, self.capacity)
            return None
Beispiel #2
0
    def _graph_fn_update_records(self, indices, update):
        """
        Overwrites the priorities of the given memory indices in both segment trees.

        Each raw value in `update` is raised to the power `self.alpha` and inserted into
        the sum and the min segment tree at the matching entry of `indices`. The inserts
        run sequentially inside a `tf.while_loop`.

        Args:
            indices: Batch of memory indices whose priorities are replaced.
            update: Batch of new raw priority values, aligned element-wise with `indices`.

        Returns:
            A no-op that is control-dependent on the assignment to `self.max_priority`.
        """
        num_records = get_batch_size(indices)
        # NOTE(review): the running maximum starts at 0.0, so `self.max_priority` ends up
        # as the maximum of this batch only, not max(previous value, batch) -- confirm
        # that this is intended.
        max_priority = 0.0

        # Update has to be sequential.
        def insert_body(i, max_priority_):
            priority = tf.pow(x=update[i], y=self.alpha)

            sum_insert = self.sum_segment_tree.insert(index=indices[i],
                                                      element=priority,
                                                      insert_op=tf.add)
            min_insert = self.min_segment_tree.insert(index=indices[i],
                                                      element=priority,
                                                      insert_op=tf.minimum)
            # Keep track of current max priority element.
            max_priority_ = tf.maximum(x=max_priority_, y=priority)

            # Only advance the loop counter once both tree inserts have run.
            with tf.control_dependencies(
                    control_inputs=[tf.group(sum_insert, min_insert)]):
                # TODO: This confuses the auto-return value detector.
                return i + 1, max_priority_

        def cond(i, max_priority_):
            # Visit every position 0 .. num_records - 1. The previous bound
            # (`num_records - 1`) skipped the last record of each batch and made
            # single-record updates a no-op.
            return i < num_records

        _, max_priority = tf.while_loop(cond=cond,
                                        body=insert_body,
                                        loop_vars=(0, max_priority))

        assignment = self.assign_variable(ref=self.max_priority,
                                          value=max_priority)
        with tf.control_dependencies(control_inputs=[assignment]):
            return tf.no_op()
Beispiel #3
0
    def _graph_fn_insert_records(self, records):
        """
        Writes a batch of records into the registry variables starting at the current
        write index (wrapping at `self.capacity`), then advances index and size.

        Args:
            records: Mapping from record key to batched values. Must contain "terminals"
                (the batch size is taken from it) and every key of `self.record_registry`.

        Returns:
            A no-op that is control-dependent on the index and size assignments.
        """
        batch_len = get_batch_size(records["terminals"])

        # Target slots: `batch_len` consecutive positions from the write index, modulo capacity.
        slots = tf.range(start=self.index, limit=self.index + batch_len) % self.capacity

        # Scatter every record component into its variable at the target slots.
        scatter_ops = [
            self.scatter_update_variable(
                variable=self.record_registry[key],
                indices=slots,
                updates=records[key]
            )
            for key in self.record_registry
        ]

        # Bookkeeping runs only after all record writes.
        with tf.control_dependencies(control_inputs=scatter_ops):
            next_index = (self.index + batch_len) % self.capacity
            advance_index = self.assign_variable(ref=self.index, value=next_index)
            next_size = tf.minimum(x=(self.read_variable(self.size) + batch_len), y=self.capacity)
            resize = self.assign_variable(self.size, value=next_size)

        # There is no result value; return an op that forces the bookkeeping to execute.
        with tf.control_dependencies(control_inputs=[advance_index, resize]):
            return tf.no_op()
Beispiel #4
0
    def _graph_fn_sample(self, sample_size, inputs):
        """
        Draws `sample_size` random batch positions and gathers those positions from
        every input tensor, so all outputs share the same sampled rows.

        Indices are drawn independently from a uniform distribution over
        [0, batch size), so the same position may be picked more than once.

        Args:
            sample_size (SingleDataOp[int]): Number of positions to draw.
            inputs (FlattenedDataOp): Tensors to sample from. The batch size is read
                from the first value only, so all values are expected to share it.

        Returns:
            FlattenedDataOp: The gathered tensors under the same keys as `inputs`
                (tf backend). Nothing is returned for any other backend.
        """
        batch_size = get_batch_size(next(iter(inputs.values())))

        # Only the tf backend is implemented here.
        if get_backend() != "tf":
            return None

        picked = tf.random_uniform(shape=(sample_size, ), maxval=batch_size, dtype=tf.int32)

        subsample = FlattenedDataOp()
        for key, tensor in inputs.items():
            subsample[key] = tf.gather(params=tensor, indices=picked)
        return subsample
Beispiel #5
0
    def _graph_fn_insert_records(self, records):
        """
        Stores a batch of records at the current write position (wrapping at
        `self.capacity`) and inserts `self.max_priority ** self.alpha` for every written
        slot into the sum and the min segment tree.

        Args:
            records: Mapping from record key to batched values. Must contain "terminals"
                (the batch size is taken from it) and every key of `self.record_registry`.

        Returns:
            A no-op that runs after the record writes, the index/size assignments and
            both segment tree insert loops (in that order).
        """
        batch_len = get_batch_size(records["terminals"])
        write_pos = self.read_variable(self.index)
        slots = tf.range(start=write_pos, limit=write_pos + batch_len) % self.capacity

        # One scatter update per registered record component.
        scatter_ops = [
            self.scatter_update_variable(
                variable=self.record_registry[key],
                indices=slots,
                updates=records[key]
            )
            for key in self.record_registry
        ]

        # Advance write position and size once the records are in place.
        with tf.control_dependencies(control_inputs=scatter_ops):
            next_pos = (write_pos + batch_len) % self.capacity
            advance_index = self.assign_variable(ref=self.index, value=next_pos)
            next_size = tf.minimum(x=(self.read_variable(self.size) + batch_len), y=self.capacity)
            resize = self.assign_variable(self.size, value=next_size)
            bookkeeping_ops = [advance_index, resize]

        # Every new slot enters the trees with this value.
        initial_priority = tf.pow(x=self.max_priority, y=self.alpha)

        def within_batch(i):
            return i < batch_len

        def make_insert_step(segment_tree, insert_op):
            # Builds a while-loop body that inserts slot `i` into `segment_tree` and
            # increments the counter only after that insert has run.
            def step(i):
                inserted = segment_tree.insert(slots[i], initial_priority, insert_op)
                with tf.control_dependencies(control_inputs=[inserted]):
                    return i + 1
            return step

        # Sum tree first, strictly after the bookkeeping updates.
        with tf.control_dependencies(control_inputs=bookkeeping_ops):
            sum_loop = tf.while_loop(
                cond=within_batch,
                body=make_insert_step(self.sum_segment_tree, tf.add),
                loop_vars=[0]
            )

        # Min tree second, strictly after the sum tree loop.
        with tf.control_dependencies(control_inputs=[sum_loop]):
            min_loop = tf.while_loop(
                cond=within_batch,
                body=make_insert_step(self.min_segment_tree, tf.minimum),
                loop_vars=[0]
            )

        # No result value; return an op that forces the whole chain to execute.
        with tf.control_dependencies(control_inputs=[min_loop]):
            return tf.no_op()
Beispiel #6
0
    def _graph_fn_insert_records(self, records):
        """
        Inserts a batch of records into the buffer and keeps the episode bookkeeping
        (`self.episode_indices`, `self.num_episodes`) in sync with the overwritten and
        the newly written terminal flags.

        The write positions start at the current index and wrap around at
        `self.capacity`. Terminals already stored at those positions are counted as
        episodes being overwritten; terminals in `records` are counted as episodes being
        added.

        Args:
            records: Mapping from record key to batched values. Must contain "terminals"
                (used for the batch size and the episode counts) and every key of
                `self.record_registry`.

        Returns:
            tf backend: a no-op that is control-dependent on the record scatter updates.
            pytorch backend: None.
        """
        if get_backend() == "tf":
            num_records = get_batch_size(records["terminals"])
            index = self.read_variable(self.index)

            # Episodes before inserting these records.
            prev_num_episodes = self.read_variable(self.num_episodes)
            # Positions to overwrite: `num_records` consecutive slots, modulo capacity.
            update_indices = tf.range(
                start=index, limit=index + num_records) % self.capacity

            # Episodes previously existing in the range we inserted to as indicated
            # by count of terminals in the that slice.
            insert_terminal_slice = self.read_variable(
                self.record_registry['terminals'], update_indices)

            # Shift episode indices.
            # All reads above must be evaluated before any of the updates below.
            with tf.control_dependencies([
                    update_indices, index, prev_num_episodes,
                    insert_terminal_slice
            ]):
                index_updates = []

                # Newly inserted episodes.
                inserted_episodes = tf.reduce_sum(input_tensor=tf.cast(
                    records['terminals'], dtype=tf.int32),
                                                  axis=0)
                # Episodes that get overwritten by this insert.
                episodes_in_insert_range = tf.reduce_sum(input_tensor=tf.cast(
                    insert_terminal_slice, dtype=tf.int32),
                                                         axis=0)
                num_episode_update = prev_num_episodes - episodes_in_insert_range + inserted_episodes

                # Shift contiguous episode indices.
                # Drops the first `episodes_in_insert_range` entries by moving the
                # remaining ones to the front of `self.episode_indices`.
                index_updates.append(
                    self.assign_variable(
                        ref=self.episode_indices[:prev_num_episodes -
                                                 episodes_in_insert_range],
                        value=self.episode_indices[
                            episodes_in_insert_range:prev_num_episodes]))

                # Insert new episodes starting at previous count minus the ones we removed,
                # ending at previous count minus removed + inserted.
                slice_start = prev_num_episodes - episodes_in_insert_range
                slice_end = num_episode_update
                # update_indices = tf.Print(update_indices, [update_indices, tf.shape(update_indices)],
                #                           summarize=100, message='\n update indices / shape = ')

            # Update indices and size.
            # Runs after the shift above; `index_updates` is rebound to a fresh list here.
            with tf.control_dependencies(index_updates):
                index_updates = []

                # Actually update indices.
                # Despite its name, `mask` holds the write positions whose record is
                # terminal (the result of the boolean mask), not the mask itself.
                mask = tf.boolean_mask(tensor=update_indices,
                                       mask=records['terminals'])
                index_updates.append(
                    self.assign_variable(
                        ref=self.episode_indices[slice_start:slice_end],
                        value=mask))

                # Assign final new episode count.
                index_updates.append(
                    self.assign_variable(self.num_episodes,
                                         num_episode_update))

                # Advance the write index (with wrap-around) and grow the size up to capacity.
                index_updates.append(
                    self.assign_variable(ref=self.index,
                                         value=(index + num_records) %
                                         self.capacity))
                update_size = tf.minimum(x=(self.read_variable(self.size) +
                                            num_records),
                                         y=self.capacity)
                index_updates.append(
                    self.assign_variable(self.size, value=update_size))

            # Updates all the necessary sub-variables in the record.
            # The record writes are sequenced after all bookkeeping updates.
            with tf.control_dependencies(index_updates):
                record_updates = []
                for key in self.record_registry:
                    record_updates.append(
                        self.scatter_update_variable(
                            variable=self.record_registry[key],
                            indices=update_indices,
                            updates=records[key]))

            # Nothing to return.
            with tf.control_dependencies(control_inputs=record_updates):
                return tf.no_op()
        elif get_backend() == "pytorch":
            # TODO: Unclear if we should do this in numpy and then convert to torch once we sample.
            num_records = get_batch_size(records["terminals"])
            # Positions to overwrite: `num_records` consecutive slots, modulo capacity.
            update_indices = torch.arange(
                self.index, self.index + num_records) % self.capacity

            # Newly inserted episodes.
            inserted_episodes = torch.sum(records['terminals'].int(), 0)

            # Episodes previously existing in the range we inserted to as indicated
            # by count of terminals in the that slice.
            episodes_in_insert_range = 0
            # Count terminals in inserted range.
            for index in update_indices:
                episodes_in_insert_range += int(
                    self.record_registry["terminals"][index])
            num_episode_update = self.num_episodes - episodes_in_insert_range + inserted_episodes
            # Drop the overwritten episodes by moving the remaining entries to the front.
            self.episode_indices[:self.num_episodes - episodes_in_insert_range] = \
                self.episode_indices[episodes_in_insert_range:self.num_episodes]

            # Insert new episodes starting at previous count minus the ones we removed,
            # ending at previous count minus removed + inserted.
            slice_start = self.num_episodes - episodes_in_insert_range
            slice_end = num_episode_update

            # `mask` holds the write positions whose record is terminal.
            byte_terminals = records["terminals"].byte()
            mask = torch.masked_select(update_indices, byte_terminals)
            self.episode_indices[slice_start:slice_end] = mask

            # Update indices.
            self.num_episodes = int(num_episode_update)
            self.index = (self.index + num_records) % self.capacity
            self.size = min(self.size + num_records, self.capacity)

            # Updates all the necessary sub-variables in the record.
            for key in self.record_registry:
                for i, val in zip(update_indices, records[key]):
                    self.record_registry[key][i] = val

            # The TF version returns no-op, return None so return-val inference system does not throw error.
            return None
Beispiel #7
0
    def _graph_fn_insert_records(self, records):
        """
        Inserts a batch of records at the current write index (wrapping at
        `self.capacity`) and updates index and size. When `self.episode_semantics` is
        set, the episode bookkeeping (`self.episode_indices`, `self.num_episodes`) is
        updated as well.

        Args:
            records: Mapping from record key to batched values. Must contain "terminals"
                (used for the batch size and, with episode semantics, the episode
                counts) and every key of `self.record_registry`.

        Returns:
            A no-op that is control-dependent on the record scatter updates.
        """
        num_records = get_batch_size(records["terminals"])
        index = self.read_variable(self.index)
        # Positions to overwrite: `num_records` consecutive slots, modulo capacity.
        update_indices = tf.range(start=index, limit=index + num_records) % self.capacity

        # update_indices = tf.Print(update_indices, [index, num_records, update_indices],
        #  summarize=100, message='index|num|indices')
        # update_indices = tf.Print(update_indices, [tf.shape(update_indices),
        #                                            tf.shape(records["terminals"])],
        #                           summarize=100, message='shape indices|shape recods')
        # Update indices and size.
        with tf.control_dependencies([update_indices]):
            index_updates = []
            if self.episode_semantics:
                # Episodes before inserting these records.
                prev_num_episodes = self.read_variable(self.num_episodes)

                # Newly inserted episodes.
                inserted_episodes = tf.reduce_sum(input_tensor=tf.cast(records['terminals'], dtype=tf.int32), axis=0)

                # Episodes previously existing in the range we inserted to as indicated
                # by count of terminals in the that slice.
                insert_terminal_slice = self.read_variable(self.record_registry['terminals'], update_indices)
                episodes_in_insert_range = tf.reduce_sum(
                    input_tensor=tf.cast(insert_terminal_slice, dtype=tf.int32), axis=0
                )

                # prev_num_episodes = tf.Print(prev_num_episodes, [
                #     prev_num_episodes,
                #     episodes_in_insert_range,
                #     inserted_episodes],
                #     summarize=100, message='previous num eps / prev episodes in insert range / inserted eps = '
                # )
                num_episode_update = prev_num_episodes - episodes_in_insert_range + inserted_episodes

                # prev_num_episodes = tf.Print(prev_num_episodes, [prev_num_episodes, episodes_in_insert_range],
                #                             summarize=100, message='num eps, eps in insert range =')
                # Remove previous episodes in inserted range.
                # NOTE(review): both slices below extend to `prev_num_episodes + 1`, i.e. one
                # entry past the current episode count -- confirm that this extra element
                # is intended.
                index_updates.append(self.assign_variable(
                        ref=self.episode_indices[:prev_num_episodes + 1 - episodes_in_insert_range],
                        value=self.episode_indices[episodes_in_insert_range:prev_num_episodes + 1]
                ))

                # Insert new episodes starting at previous count minus the ones we removed,
                # ending at previous count minus removed + inserted.
                slice_start = prev_num_episodes - episodes_in_insert_range
                slice_end = num_episode_update
                # update_indices = tf.Print(update_indices, [update_indices, tf.shape(update_indices)],
                #                           summarize=100, message='\n update indices / shape = ')
                # slice_start = tf.Print(
                #     slice_start, [slice_start, slice_end, self.episode_indices],
                #     summarize=100,
                #     message='\n slice start/ slice end / episode indices before = '
                # )

                # Runs after the shift above; `index_updates` is rebound to a fresh list,
                # which is also the list the index/size assignments below are appended to.
                with tf.control_dependencies(index_updates):
                    index_updates = []
                    # Despite its name, `mask` holds the write positions whose record is
                    # terminal (the result of the boolean mask), not the mask itself.
                    mask = tf.boolean_mask(tensor=update_indices, mask=records['terminals'])
                    # mask = tf.Print(mask, [mask, update_indices, records['/terminals']], summarize=100,
                    #     message='\n mask /  update indices / records-terminal')

                    index_updates.append(self.assign_variable(
                        ref=self.episode_indices[slice_start:slice_end],
                        value=mask
                    ))
                    # num_episode_update = tf.Print(num_episode_update, [num_episode_update, self.episode_indices],
                    #     summarize=100,  message='\n num episodes / episode indices after: ')

                    # Assign final new episode count.
                    index_updates.append(self.assign_variable(self.num_episodes, num_episode_update))

            # Advance the write index (with wrap-around) and grow the size up to capacity.
            index_updates.append(self.assign_variable(ref=self.index, value=(index + num_records) % self.capacity))
            update_size = tf.minimum(x=(self.read_variable(self.size) + num_records), y=self.capacity)
            index_updates.append(self.assign_variable(self.size, value=update_size))

        # Updates all the necessary sub-variables in the record.
        # The record writes are sequenced after the updates collected in `index_updates`.
        with tf.control_dependencies(index_updates):
            record_updates = []
            for key in self.record_registry:
                record_updates.append(self.scatter_update_variable(
                    variable=self.record_registry[key],
                    indices=update_indices,
                    updates=records[key]
                ))

        # Nothing to return.
        with tf.control_dependencies(control_inputs=record_updates):
            return tf.no_op()