Exemplo n.º 1
0
    def get_action(self,
                   states,
                   internals=None,
                   use_exploration=True,
                   apply_preprocessing=True,
                   extra_returns=None):
        """
        Runs one of the root component's action API methods on `states` and returns the action(s).

        Args:
            states: The state(s) to act on. With `apply_preprocessing=True` they are first passed through
                `self.state_space.force_batch`; otherwise they are handed to the graph unchanged.
            internals: Not used by this method.
            use_exploration (bool): Negated and passed to the graph as its second input
                (deterministic = not use_exploration).
            apply_preprocessing (bool): Selects `get_preprocessed_state_and_action` (True) or
                `action_from_preprocessed_state` (False) on `self.root_component`.
            extra_returns (Optional[Set[str],str]): Optional string or set of strings for additional return
                values (besides the actions). Only 'preprocessed_states' is examined by this method.

        Returns:
            The action alone, or the tuple (action, preprocessed states) if 'preprocessed_states' was requested.
        """
        # TODO: common pattern - move to Agent
        if isinstance(extra_returns, str):
            extra_returns = {extra_returns}
        elif not extra_returns:
            extra_returns = set()

        # Raw states get batched via the state space; preprocessed ones are used as they are.
        if apply_preprocessing:
            api_method = self.root_component.get_preprocessed_state_and_action
            batch = self.state_space.force_batch(states)
        else:
            api_method = self.root_component.action_from_preprocessed_state
            batch = states
        # The batch has exactly one more rank than the input -> strip that rank again before returning.
        squeeze = batch.ndim == np.asarray(states).ndim + 1

        # Every state in the batch counts as one timestep.
        self.timesteps += len(batch)

        # Output 0 is used as the action below; output 1 is only pulled when the preprocessed
        # states were requested.
        want_preprocessed = "preprocessed_states" in extra_returns
        pulled_ops = [0, 1] if want_preprocessed else [0]
        deterministic = not use_exploration
        outputs = force_list(
            self.graph_executor.execute(
                (api_method, [batched_input, deterministic], pulled_ops)
            ) if False else self.graph_executor.execute(
                (api_method, [batch, deterministic], pulled_ops)
            )
        )

        def _discretize(key, action):
            # IntBox sub-actions arrive as relaxed one-hot samples: take the argmax, keep the dtype.
            if isinstance(self.flat_action_space[key], IntBox):
                return np.argmax(action, axis=-1).astype(action.dtype)
            return action

        # Convert relaxed one-hot samples back into int actions for discrete action spaces.
        if isinstance(self.action_space, ContainerSpace):
            outputs[0] = outputs[0].map(mapping=_discretize)
        elif isinstance(self.action_space, IntBox):
            outputs[0] = np.argmax(outputs[0], axis=-1).astype(self.action_space.dtype)

        if squeeze:
            outputs[0] = strip_list(outputs[0])

        if want_preprocessed:
            return outputs[0], outputs[1]
        return outputs[0]
Exemplo n.º 2
0
    def get_action(self,
                   states,
                   internals=None,
                   use_exploration=True,
                   apply_preprocessing=True,
                   extra_returns=None,
                   time_percentage=None):
        """
        Asks the graph executor for action(s) for the given `states`.

        Args:
            states: The state(s) to act on. With `apply_preprocessing=True` they are first passed through
                `self.state_space.force_batch`; otherwise they are handed to the graph unchanged.
            internals: Not used by this method.
            use_exploration (bool): Passed to the graph as its third input.
            apply_preprocessing (bool): Selects the API method "get_preprocessed_state_and_action" (True) or
                "action_from_preprocessed_state" (False).
            extra_returns (Optional[Set[str],str]): Optional string or set of strings for additional return
                values (besides the actions). Only 'preprocessed_states' is examined by this method.
            time_percentage (Optional[float]): Passed to the graph as its second input. If None, it is computed
                as `self.timesteps` divided by the "max_timesteps" entry of `self.update_spec` (default: 1e6).

        Returns:
            Whatever the graph executor returns for the pulled ops (0=action, 1=preprocessed_states), passed
            through `strip_list` if `force_batch` reported that the batch rank should be removed.
        """
        # TODO: Move update_spec to Worker. Agent should not hold these execution details.
        if time_percentage is None:
            elapsed = self.timesteps
            time_percentage = elapsed / self.update_spec.get("max_timesteps", 1e6)

        if isinstance(extra_returns, str):
            extra_returns = {extra_returns}
        elif not extra_returns:
            extra_returns = set()

        if not apply_preprocessing:
            # Preprocessed states are used as they are and the result is never unbatched.
            api_name = "action_from_preprocessed_state"
            batch, squeeze = states, False
        else:
            # Raw states -> let the state space batch them and report whether to unbatch later.
            api_name = "get_preprocessed_state_and_action"
            batch, squeeze = self.state_space.force_batch(states)

        # Every state in the batch counts as one timestep.
        self.timesteps += len(batch)

        # 0=action, 1=preprocessed_states (only pulled on request).
        if "preprocessed_states" in extra_returns:
            pulled_ops = [0, 1]
        else:
            pulled_ops = [0]

        graph_inputs = [batch, time_percentage, use_exploration]
        result = self.graph_executor.execute((api_name, graph_inputs, pulled_ops))
        return strip_list(result) if squeeze else result
Exemplo n.º 3
0
    def get_action(self,
                   states,
                   internals=None,
                   use_exploration=True,
                   apply_preprocessing=True,
                   extra_returns=None):
        """
        Asks the graph executor for action(s) for the given `states`.

        Args:
            states: The state(s) to act on. With `apply_preprocessing=True` they are first passed through
                `self.state_space.force_batch`; otherwise they are handed to the graph unchanged.
            internals: Not used by this method.
            use_exploration (bool): Passed to the graph as its third input.
            apply_preprocessing (bool): Selects the API method "get_preprocessed_state_and_action" (True) or
                "action_from_preprocessed_state" (False).
            extra_returns (Optional[Set[str],str]): Optional string or set of strings for additional return
                values (besides the actions). Only 'preprocessed_states' is examined by this method.

        Returns:
            Whatever the graph executor returns for the pulled ops (1=action, 0=preprocessed_states, in that
            order), passed through `strip_list` if batching added a rank to the input.
        """
        if isinstance(extra_returns, str):
            extra_returns = {extra_returns}
        elif not extra_returns:
            extra_returns = set()

        # Raw states get batched via the state space; preprocessed ones are used as they are.
        if apply_preprocessing:
            api_name = "get_preprocessed_state_and_action"
            batch = self.state_space.force_batch(states)
        else:
            api_name = "action_from_preprocessed_state"
            batch = states
        # The batch has exactly one more rank than the input -> strip that rank again before returning.
        squeeze = batch.ndim == np.asarray(states).ndim + 1

        # Every state in the batch counts as one timestep. The updated counter is what the graph receives.
        self.timesteps += len(batch)

        # 0=preprocessed_states, 1=action: the action is always pulled first.
        if "preprocessed_states" in extra_returns:
            pulled_ops = [1, 0]
        else:
            pulled_ops = [1]

        graph_inputs = [batch, self.timesteps, use_exploration]
        result = self.graph_executor.execute((api_name, graph_inputs, pulled_ops))
        return strip_list(result) if squeeze else result
Exemplo n.º 4
0
    def get_action(self, states, internals=None, use_exploration=True, apply_preprocessing=True, extra_returns=None,
                   time_percentage=None):
        """
        Asks the graph executor for action(s) and makes sure the first action of the result differs from the
        previously returned one (`self.prev_action`).

        Args:
            states: The state(s) to act on. With `apply_preprocessing=True` they are first passed through
                `self.state_space.force_batch(states, horizontal=False)`; otherwise they are handed to the
                graph unchanged.
            internals: Not used by this method.
            use_exploration (bool): Negated and passed to the graph as its second input
                (deterministic = not use_exploration).
            apply_preprocessing (bool): Selects the API method "get_preprocessed_state_and_action" (True) or
                "action_from_preprocessed_state" (False).
            extra_returns (Optional[Set[str],str]): Optional string or set of strings for additional return
                values (besides the actions). Only 'preprocessed_states' is examined by this method.
            time_percentage: Not used by this method.

        Returns:
            The graph executor's result (0=action, 1=preprocessed_states). If its first action equals
            `self.prev_action`, the result is replaced by a tuple of a new one-element int array (drawn
            uniformly from the other values in 0..3) and the original element 1. The result is passed through
            `strip_list` if `force_batch` reported that the batch rank should be removed.
        """
        extra_returns = {extra_returns} if isinstance(extra_returns, str) else (extra_returns or set())
        # States come in without preprocessing -> use state space.
        if apply_preprocessing:
            call_method = "get_preprocessed_state_and_action"
            batched_states, remove_batch_rank = self.state_space.force_batch(states, horizontal=False)
        # States are already pre-processed (and therefore also batched).
        else:
            call_method = "action_from_preprocessed_state"
            batched_states = states
            remove_batch_rank = False

        # Increase timesteps by the batch size (number of states in batch).
        # For containers, the batch size is taken from the first contained element.
        if isinstance(batched_states, dict):
            batch_size = len(batched_states[next(iter(batched_states))])
        elif isinstance(batched_states, tuple):
            batch_size = len(next(iter(batched_states)))
        else:
            batch_size = len(batched_states)
        self.timesteps += batch_size

        # Control, which return value to "pull" (depending on `extra_returns`).
        return_ops = [0, 1] if "preprocessed_states" in extra_returns else [0]  # 1=preprocessed_states, 0=action
        ret = self.graph_executor.execute((
            call_method,
            [batched_states, not use_exploration],  # deterministic = not use_exploration
            return_ops
        ))

        # NOTE(review): everything below indexes `ret[1]`, so it presumably only works if
        # 'preprocessed_states' was requested via `extra_returns` -- confirm against the callers.
        # NOTE(review): `self.prev_action` is read before it is written here, so it must be
        # initialized outside this method.

        # If the last three entries of the first row of `ret[1]` are all zero, forget the previous action.
        if list(ret[1][0][-3:]) == [0, 0, 0]:
            self.prev_action = -1

        # Never repeat the previous action: draw a different one from the hard-coded action set instead.
        if ret[0][0] == self.prev_action:
            candidates = [action for action in [0, 1, 2, 3] if action != self.prev_action]
            substitute = candidates[np.random.randint(0, len(candidates))]
            # Builtin `int` is what the `np.int` alias stood for; the alias was removed in NumPy 1.24.
            ret = (np.array([substitute], dtype=int), ret[1])

        self.prev_action = ret[0][0]

        # If unbatched data came in, return unbatched data; otherwise return the batched data.
        return strip_list(ret) if remove_batch_rank else ret
Exemplo n.º 5
0
    def get_action(self,
                   states,
                   internals=None,
                   use_exploration=True,
                   apply_preprocessing=True,
                   extra_returns=None,
                   time_percentage=None):
        """
        Asks the graph executor for action(s) for the given `states`.

        Args:
            states: The state(s) to act on. With `apply_preprocessing=True` they are first passed through
                `self.state_space.force_batch(states, horizontal=False)`; otherwise they are handed to the
                graph unchanged.
            internals: Not used by this method.
            use_exploration (bool): Negated and passed to the graph as its second input
                (deterministic = not use_exploration).
            apply_preprocessing (bool): Selects the API method "get_preprocessed_state_and_action" (True) or
                "action_from_preprocessed_state" (False).
            extra_returns (Optional[Set[str],str]): Optional string or set of strings for additional return
                values (besides the actions). Only 'preprocessed_states' is examined by this method.
            time_percentage: Not used by this method.

        Returns:
            Whatever the graph executor returns for the pulled ops (0=action, 1=preprocessed_states), passed
            through `strip_list` if `force_batch` reported that the batch rank should be removed.
        """
        if isinstance(extra_returns, str):
            extra_returns = {extra_returns}
        elif not extra_returns:
            extra_returns = set()

        if not apply_preprocessing:
            # Preprocessed states are used as they are and the result is never unbatched.
            api_name = "action_from_preprocessed_state"
            batch, squeeze = states, False
        else:
            # Raw states -> let the state space batch them and report whether to unbatch later.
            api_name = "get_preprocessed_state_and_action"
            batch, squeeze = self.state_space.force_batch(states, horizontal=False)

        # Every state in the batch counts as one timestep. For dict/tuple containers the
        # batch size is read off the first contained element.
        if isinstance(batch, dict):
            sized = batch[next(iter(batch))]
        elif isinstance(batch, tuple):
            sized = next(iter(batch))
        else:
            sized = batch
        self.timesteps += len(sized)

        # 0=action, 1=preprocessed_states (only pulled on request).
        if "preprocessed_states" in extra_returns:
            pulled_ops = [0, 1]
        else:
            pulled_ops = [0]

        deterministic = not use_exploration
        result = self.graph_executor.execute((api_name, [batch, deterministic], pulled_ops))

        # Unbatched data in -> unbatched data out.
        return strip_list(result) if squeeze else result