Example #1
    def _graph_fn_get_records(self, num_records=1):
        if get_backend() == "tf":
            size = self.read_variable(self.size)

            # Sample and retrieve a random range, including terminals.
            index = self.read_variable(self.index)
            indices = tf.random_uniform(shape=(num_records,), maxval=size, dtype=tf.int32)
            indices = (index - 1 - indices) % self.capacity

            # Return default importance weight one.
            return self._read_records(indices=indices), indices, tf.ones_like(tensor=indices, dtype=tf.float32)
        elif get_backend() == "pytorch":
            indices = []
            if self.size > 0:
                indices = np.random.choice(np.arange(0, self.size), size=int(num_records))
                indices = (self.index - 1 - indices) % self.capacity
            records = OrderedDict()
            for name, variable in self.memory.items():
                records[name] = self.read_variable(
                    variable, indices,
                    dtype=util.convert_dtype(self.flat_record_space[name].dtype, to="pytorch"),
                    shape=self.flat_record_space[name].shape
                )
            records = define_by_run_unflatten(records)
            weights = torch.ones(indices.shape, dtype=torch.float32) if len(indices) > 0 \
                else torch.ones(1, dtype=torch.float32)
            return records, indices, weights
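
Example #1 samples uniform offsets into the filled region of a ring buffer and maps them back onto buffer slots via `(index - 1 - offsets) % capacity`, where `index` is the next write position. A minimal NumPy sketch of that arithmetic; the values for `capacity`, `size` and `index` are illustrative:

    import numpy as np

    capacity, size, index = 8, 5, 5                # 5 records stored; next write goes to slot 5
    offsets = np.random.randint(0, size, size=3)   # uniform offsets into the filled region
    slots = (index - 1 - offsets) % capacity       # slot index - 1 holds the most recent record
    assert np.all(slots < size)                    # holds here because the buffer has not wrapped yet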
Example #2
    def _graph_fn_get_records(self, num_records=1):
        available_records = min(num_records, self.size)
        indices = []
        prob_sum = self.merged_segment_tree.sum_segment_tree.get_sum(0, self.size - 1)
        samples = np.random.random(size=(available_records,)) * prob_sum
        for sample in samples:
            indices.append(self.merged_segment_tree.sum_segment_tree.index_of_prefixsum(prefix_sum=sample))

        sum_prob = self.merged_segment_tree.sum_segment_tree.get_sum() + SMALL_NUMBER
        min_prob = self.merged_segment_tree.min_segment_tree.get_min_value() / sum_prob
        max_weight = (min_prob * self.size) ** (-self.beta)
        weights = []
        for index in indices:
            sample_prob = self.merged_segment_tree.sum_segment_tree.get(index) / sum_prob
            weight = (sample_prob * self.size) ** (-self.beta)
            weights.append(weight / max_weight)

        if get_backend() == "pytorch":
            indices = torch.tensor(indices)
            weights = torch.tensor(weights)
        else:
            indices = np.asarray(indices)
            weights = np.asarray(weights)

        records = OrderedDict()
        for name, variable in self.record_registry.items():
            records[name] = self.read_variable(
                variable, indices,
                dtype=util.convert_dtype(self.flat_record_space[name].dtype, to="pytorch")
            )
        records = define_by_run_unflatten(records)
        return records, indices, weights
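
Example #2 draws indices by inverse-transform sampling on the prefix sums of the priorities (the job of the sum segment tree) and normalizes each importance weight by the weight of the least likely record. The same logic can be sketched with plain NumPy cumulative sums standing in for the segment trees; `priorities` and `beta` below are illustrative:

    import numpy as np

    priorities = np.array([0.5, 2.0, 1.0, 0.5])
    beta, size = 0.4, len(priorities)

    cumsum = np.cumsum(priorities)                   # prefix sums over all priorities
    samples = np.random.random(size=2) * cumsum[-1]  # uniform draws in [0, total priority)
    indices = np.searchsorted(cumsum, samples)       # plays the role of index_of_prefixsum()

    probs = priorities / priorities.sum()
    max_weight = (probs.min() * size) ** (-beta)     # weight of the least likely record
    weights = (probs[indices] * size) ** (-beta) / max_weight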
Example #3
    def _graph_fn_get_action_components(self, logits, parameters,
                                        deterministic):
        ret = {}

        # TODO Clean up the checks in here wrt define-by-run processing.
        for flat_key, action_space_component in self.action_space.flatten().items():
            # Skip the distribution entirely if the action space is discrete and acting is deterministic (greedy).
            # In that case, there is no need to build a distribution in the graph on every act: the argmax over
            # the logits is the same as the argmax over the probabilities (or log-probabilities).
            if isinstance(action_space_component, IntBox) and \
                    (deterministic is True or (isinstance(deterministic, np.ndarray) and deterministic)):
                if flat_key == "":
                    return self._graph_fn_get_deterministic_action_wo_distribution(
                        logits)
                else:
                    ret[flat_key] = self._graph_fn_get_deterministic_action_wo_distribution(
                        logits.flat_key_lookup(flat_key))
            elif isinstance(action_space_component, BoolBox) and \
                    (deterministic is True or (isinstance(deterministic, np.ndarray) and deterministic)):
                if get_backend() == "tf":
                    if flat_key == "":
                        return tf.greater(logits, 0.5)
                    else:
                        ret[flat_key] = tf.greater(
                            logits.flat_key_lookup(flat_key), 0.5)
                elif get_backend() == "pytorch":
                    if flat_key == "":
                        return torch.gt(logits, 0.5)
                    else:
                        ret[flat_key] = torch.gt(
                            logits.flat_key_lookup(flat_key), 0.5)
            else:
                if flat_key == "":
                    # Still wrapped as FlattenedDataOp.
                    if isinstance(parameters, FlattenedDataOp):
                        return self.distributions[flat_key].draw(
                            parameters[flat_key], deterministic)
                    else:
                        return self.distributions[flat_key].draw(
                            parameters, deterministic)

                if isinstance(parameters, ContainerDataOp) and not \
                        (isinstance(parameters, DataOpDict) and flat_key in parameters):
                    ret[flat_key] = self.distributions[flat_key].draw(
                        parameters.flat_key_lookup(flat_key), deterministic)
                else:
                    ret[flat_key] = self.distributions[flat_key].draw(
                        parameters[flat_key], deterministic)

        if get_backend() == "tf":
            return unflatten_op(ret)
        elif get_backend() == "pytorch":
            return define_by_run_unflatten(ret)
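
Per the comments in Example #3, the two deterministic shortcuts reduce to an argmax for `IntBox` spaces and a 0.5 threshold for `BoolBox` spaces. A minimal torch sketch with illustrative tensors:

    import torch

    logits = torch.tensor([[0.1, 2.0, -0.5]])
    greedy_action = torch.argmax(logits, dim=-1)   # IntBox: argmax over logits, identical to the
                                                   # argmax over softmax (log-)probabilities
    bool_logits = torch.tensor([0.2, 0.7])
    bool_action = torch.gt(bool_logits, 0.5)       # BoolBox: plain threshold at 0.5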
Example #4
    def _graph_fn_get_episodes(self, num_episodes=1):
        if get_backend() == "tf":
            stored_episodes = self.read_variable(self.num_episodes)
            available_episodes = tf.minimum(x=num_episodes, y=stored_episodes)

            # Say we have two episodes with this layout:
            # terminals = [0 0 1 0 1]
            # episode_indices = [2, 4]
            # If we want to fetch the most recent episode, the start index is:
            # stored_episodes - 1 - num_episodes = 2 - 1 - 1 = 0, which points to buffer index 2
            # The next episode starts one element after this, hence + 1.
            # However, this would point to index -1 if stored_episodes == available_episodes,
            # in which case we want start = 0 to fetch everything.
            start = tf.cond(pred=tf.equal(x=stored_episodes,
                                          y=available_episodes),
                            true_fn=lambda: 0,
                            false_fn=lambda: self.episode_indices[
                                stored_episodes - available_episodes - 1] + 1)
            # End index is just the pointer to the most recent episode.
            limit = self.episode_indices[stored_episodes - 1]

            limit += tf.where(condition=(start < limit),
                              x=0,
                              y=self.capacity - 1)
            # limit = tf.Print(limit, [stored_episodes, start, limit], summarize=100, message="start | limit")
            indices = tf.range(start=start, limit=limit + 1) % self.capacity
            return self._read_records(indices=indices)
        elif get_backend() == "pytorch":
            stored_episodes = self.num_episodes
            available_episodes = min(num_episodes, self.num_episodes)

            if stored_episodes == available_episodes:
                start = 0
            else:
                start = self.episode_indices[stored_episodes -
                                             available_episodes - 1] + 1

            # End index is just the pointer to the most recent episode.
            limit = self.episode_indices[stored_episodes - 1]
            if start >= limit:
                limit += self.capacity - 1
            indices = torch.arange(start, limit + 1) % self.capacity

            records = OrderedDict()
            for name, variable in self.memory.items():
                records[name] = self.read_variable(
                    variable,
                    indices,
                    dtype=util.convert_dtype(
                        self.flat_record_space[name].dtype, to="pytorch"),
                    shape=self.flat_record_space[name].shape)
            records = define_by_run_unflatten(records)
            return records
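
Tracing the pytorch branch of Example #4 with the layout from the comment (`terminals = [0 0 1 0 1]`, so `episode_indices = [2, 4]`) shows how `start` and `limit` pick out the most recent episode; the values below mirror that example:

    import torch

    capacity = 5
    episode_indices = [2, 4]   # terminals = [0 0 1 0 1]
    stored_episodes, available_episodes = 2, 1

    # start = episode_indices[2 - 1 - 1] + 1 = 3; limit = episode_indices[1] = 4.
    start = 0 if stored_episodes == available_episodes \
        else episode_indices[stored_episodes - available_episodes - 1] + 1
    limit = episode_indices[stored_episodes - 1]
    if start >= limit:
        limit += capacity - 1
    indices = torch.arange(start, limit + 1) % capacity   # tensor([3, 4]): the latest episode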
Example #5
    def _graph_fn_get_records(self, num_records=1):
        if get_backend() == "tf":
            stored_records = self.read_variable(self.size)
            available_records = tf.minimum(x=num_records, y=stored_records)
            index = self.read_variable(self.index)
            indices = tf.range(start=index - available_records, limit=index) % self.capacity
            return self._read_records(indices=indices)
        elif get_backend() == "pytorch":
            available_records = min(num_records, self.size)
            indices = np.arange(self.index - available_records, self.index) % self.capacity
            records = OrderedDict()

            for name, variable in self.record_registry.items():
                records[name] = self.read_variable(
                    variable, indices,
                    dtype=util.convert_dtype(self.flat_record_space[name].dtype, to="pytorch"),
                    shape=self.flat_record_space[name].shape
                )

            records = define_by_run_unflatten(records)
            return records
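
Example #5 fetches the `available_records` most recent entries in chronological order; with a wrapped ring buffer, the negative start of the range folds back to the end. A small NumPy sketch with illustrative values:

    import numpy as np

    capacity, size, index = 6, 6, 2        # buffer full; next write at slot 2
    available_records = min(4, size)
    indices = np.arange(index - available_records, index) % capacity   # array([4, 5, 0, 1])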
    def clean_dict(tensor_dict):
        """
        Detach tensor values in nested dict.
        Args:
            tensor_dict (dict): Dict containing torch tensor.

        Returns:
            dict: Dict containing numpy arrays.
        """
        # Un-nest.
        param = define_by_run_flatten(tensor_dict)
        ret = {}

        # Detach tensor values.
        for key, value in param.items():
            if isinstance(value, torch.Tensor):
                ret[key] = value.detach().numpy()

        # Pack again.
        return define_by_run_unflatten(ret)
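
A short usage sketch for `clean_dict` (assuming the RLgraph flatten/unflatten helpers it relies on are in scope); the nested layout below is illustrative:

    import torch

    nested = {
        "states": {"pixels": torch.zeros(2, 3, requires_grad=True)},
        "rewards": torch.ones(2),
    }
    detached = clean_dict(nested)   # same nesting, values converted to numpy arrays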