Exemplo n.º 1
0
    def __add__(self, other):
        """
        Build a new state that combines this state with another one.

        The two operands are merged as an ordered union of their inner states, so the operation is not
        commutative: the order of the operands determines the order of the resulting inner states.

        Args:
            other (State): state to combine with the current one.

        Returns:
            State: the new combined state.

        Raises:
            TypeError: if `other` is not a `State`.

        Examples:
            s1 = JntPositionState(robot)
            s2 = JntVelocityState(robot)
            s = s1 + s2     # = State([JntPositionState(robot), JntVelocityState(robot)])

            s1 = State([JntPositionState(robot), JntVelocityState(robot)])
            s2 = State([JntPositionState(robot), LinkPositionState(robot)])
            s = s1 + s2     # = State([JntPositionState(robot), JntVelocityState(robot), LinkPositionState(robot)])
        """
        if not isinstance(other, State):
            raise TypeError("Expecting another state, instead got {}".format(
                type(other)))

        def as_ordered_set(state):
            # a state that holds data is a leaf and is wrapped as a single element;
            # otherwise its inner states are used directly
            if state._data is None:
                return state._states
            return OrderedSet([state])

        return State(as_ordered_set(self) + as_ordered_set(other))
Exemplo n.º 2
0
    def __add__(self, other):
        """
        Build a new action that combines this action with another one.

        The two operands are merged as an ordered union of their inner actions, so the operation is not
        commutative: the order of the operands determines the order of the resulting inner actions.

        Args:
            other (Action): action to combine with the current one.

        Returns:
            Action: the new combined action.

        Raises:
            TypeError: if `other` is not an `Action`.

        Examples:
            s1 = JntPositionAction(robot)
            s2 = JntVelocityAction(robot)
            s = s1 + s2     # = Action([JntPositionAction(robot), JntVelocityAction(robot)])

            s1 = Action([JntPositionAction(robot), JntVelocityAction(robot)])
            s2 = Action([JntPositionAction(robot), LinkPositionAction(robot)])
            s = s1 + s2     # = Action([JntPositionAction(robot), JntVelocityAction(robot), LinkPositionAction(robot)])
        """
        if not isinstance(other, Action):
            raise TypeError("Expecting another action, instead got {}".format(
                type(other)))

        def as_ordered_set(action):
            # an action that holds data is a leaf and is wrapped as a single element;
            # otherwise its inner actions are used directly
            if action._data is None:
                return action._actions
            return OrderedSet([action])

        return Action(as_ordered_set(self) + as_ordered_set(other))
Exemplo n.º 3
0
    def __init__(self, actions=(), data=None, space=None, name=None, ticks=1):
        """
        Create an action that either wraps some data, or groups several other actions.

        Args:
            actions (list/tuple/set/OrderedSet of Action, None): actions to combine together. It must be empty (or
                None) when `data` is given.
            data (np.ndarray, list, tuple, int, float, None): data held by this action. Lists, tuples and scalars
                are converted to a numpy array (a scalar becomes a one-element array).
            space (gym.space, None): space associated with the given data.
            name (str, None): name of the action.
            ticks (int): number of ticks to sleep before setting the next action data.

        Raises:
            TypeError: if `actions` is not one of the supported containers, or if `data` has an unsupported type.
            ValueError: if a non-empty `actions` and `data` are both provided.
        """
        # normalize and validate the container of inner actions
        actions = tuple() if actions is None else actions
        if not isinstance(actions, (list, tuple, set, OrderedSet)):
            raise TypeError(
                "Expecting a list, tuple, or (ordered) set of actions.")

        # an action is either a combination of actions or a data holder, never both
        if data is not None and len(actions) > 0:
            raise ValueError(
                "Please specify only one of the argument `actions` xor `data`, but not both."
            )

        # convert the data to a numpy array if it is not one already
        if data is not None and not isinstance(data, np.ndarray):
            if isinstance(data, (list, tuple)):
                data = np.array(data)
            elif isinstance(data, (int, float)):
                data = np.array([data])
            else:
                raise TypeError(
                    "Expecting a numpy array, a list/tuple of int/float, or an int/float for 'data'"
                )

        # The following attributes should normally be set in the child classes
        self._data = data
        if data is None:
            self._torch_data = None
        else:
            self._torch_data = torch.from_numpy(data).float()
        self._space = space
        self._distribution = None  # for sampling
        self._normalizer = None
        self._noiser = None  # for noise
        self.name = name

        # ordered set of inner actions; only populated when this action holds no data
        self._actions = OrderedSet()
        if self._data is None:
            self.add(actions)

        # tick counter and tick period
        self.cnt = 0
        self.ticks = int(ticks)
Exemplo n.º 4
0
    def __sub__(self, other):
        """
        Return what is left of the current state once the other state(s) have been removed.

        Args:
            other (State): state(s) to remove.

        Returns:
            State: the single remaining inner state if exactly one is left, otherwise a new combined state built
                from the remaining ones.

        Raises:
            TypeError: if `other` is not a `State`.
        """
        if not isinstance(other, State):
            raise TypeError("Expecting another state, instead got {}".format(
                type(other)))

        def as_ordered_set(state):
            # a state that holds data is a leaf; otherwise use its inner states
            if state._data is None:
                return state._states
            return OrderedSet([state])

        remaining = as_ordered_set(self) - as_ordered_set(other)
        # a single leftover is returned as is instead of being wrapped in a combined state
        return remaining[0] if len(remaining) == 1 else State(remaining)
Exemplo n.º 5
0
    def __sub__(self, other):
        """
        Return what is left of the current action once the other action(s) have been removed.

        Args:
            other (Action): action(s) to remove.

        Returns:
            Action: the single remaining inner action if exactly one is left, otherwise a new combined action built
                from the remaining ones.

        Raises:
            TypeError: if `other` is not an `Action`.
        """
        if not isinstance(other, Action):
            raise TypeError("Expecting another action, instead got {}".format(
                type(other)))

        def as_ordered_set(action):
            # an action that holds data is a leaf; otherwise use its inner actions
            if action._data is None:
                return action._actions
            return OrderedSet([action])

        remaining = as_ordered_set(self) - as_ordered_set(other)
        # a single leftover is returned as is instead of being wrapped in a combined action
        return remaining[0] if len(remaining) == 1 else Action(remaining)
Exemplo n.º 6
0
    def __isub__(self, other):
        """
        Remove one or several states from the combined state, in place.

        Args:
            other (State): state(s) to be removed.

        Returns:
            State: the current (modified) state, so that `s -= other` keeps `s` bound to this object.

        Raises:
            TypeError: if `other` is not a `State`.
            RuntimeError: if the current state holds data (i.e. it is not a combined state).
        """
        if not isinstance(other, State):
            raise TypeError("Expecting another state, instead got {}".format(
                type(other)))
        if self._data is not None:
            raise RuntimeError(
                "This operation is only available for a combined state")
        s = other._states if other._data is None else OrderedSet([other])
        self._states -= s
        # an in-place operator must return its result: without this, `s -= other` rebinds `s` to None
        return self
Exemplo n.º 7
0
    def __isub__(self, other):
        """
        Remove one or several actions from the combined action, in place.

        Args:
            other (Action): action(s) to be removed.

        Returns:
            Action: the current (modified) action, so that `a -= other` keeps `a` bound to this object.

        Raises:
            TypeError: if `other` is not an `Action`.
            RuntimeError: if the current action holds data (i.e. it is not a combined action).
        """
        if not isinstance(other, Action):
            raise TypeError("Expecting another action, instead got {}".format(
                type(other)))
        if self._data is not None:
            raise RuntimeError(
                "This operation is only available for a combined action")
        s = other._actions if other._data is None else OrderedSet([other])
        self._actions -= s
        # an in-place operator must return its result: without this, `a -= other` rebinds `a` to None
        return self
Exemplo n.º 8
0
    def __init__(self,
                 states=(),
                 data=None,
                 space=None,
                 window_size=1,
                 axis=None,
                 ticks=1,
                 name=None):
        """
        Create a state that either wraps some data, or groups several other states.

        Args:
            states (list/tuple/set/OrderedSet of State, None): states to combine together. It must be empty (or
                None) when `data` is given. As a shortcut, a numpy array passed here while `data` is None is
                interpreted as the data (i.e. `State(array)`).
            data (np.ndarray, list, tuple, int, float, None): data held by this state. Lists, tuples and scalars
                are converted to a numpy array (a scalar becomes a one-element array).
            space (gym.space, None): space associated with the given data.
            window_size (int, None): number of consecutive data to remember. For instance, remembering the current
                state :math:`s_t` and the previous one :math:`s_{t-1}` requires a window size of 2. The default of 1
                only keeps the current data. It is meant for states holding :attr:`data`, not for combined states.
            axis (int, None): how the data in the window are assembled. With None (default), they are concatenated:
                data of shape (n,) give a shape of (n*w,) where w is the window size. With an integer, they are
                stacked along that axis: (w,n) for axis=0, and (n,w) for axis=-1 or 1. It is meant for states
                holding :attr:`data`, not for combined states.
            ticks (int): number of ticks to sleep before getting the next state data.
            name (str, None): name of the state. If None, the name of the class is used.

        Raises:
            TypeError: if `states` is not one of the supported containers, or if `data` has an unsupported type.
            ValueError: if a non-empty `states` and `data` are both provided.
        """
        # normalize the container of inner states
        states = tuple() if states is None else states

        if not isinstance(states, (list, tuple, set, OrderedSet)):
            # TODO: should check that states is a list of state, however O(N)
            # the only other accepted form is the shortcut `State(data)` with a numpy array
            if data is not None or not isinstance(states, np.ndarray):
                raise TypeError(
                    "Expecting a list, tuple, or (ordered) set of states.")
            data, states = states, tuple()

        # a state is either a combination of states or a data holder, never both
        if data is not None and len(states) > 0:
            raise ValueError(
                "Please specify only one of the argument `states` xor `data`, but not both."
            )

        # convert the data to a numpy array if it is not one already
        if data is not None and not isinstance(data, np.ndarray):
            if isinstance(data, (list, tuple)):
                data = np.array(data)
            elif isinstance(data, (int, float)):
                data = np.array([data])
            else:
                raise TypeError(
                    "Expecting a numpy array, a list/tuple of int/float, or an int/float for 'data'"
                )

        # The following attributes should normally be set in the child classes
        self._data = data
        if data is None:
            self._torch_data = None
        else:
            self._torch_data = torch.from_numpy(data).float()
        self._space = space
        self._distribution = None  # for sampling
        self._normalizer = None
        self._noiser = None  # for noise
        self.name = name
        self._training_mode = False

        # ordered set of inner states; only populated when this state holds no data
        self._states = OrderedSet()
        if self._data is None:
            self.add(states)

        # data windows: assigning `window_size` creates the windows (FIFO queues)
        self._window, self._torch_window = None, None
        self.window_size = window_size
        self.axis = axis

        # tick counter and tick period
        self._cnt = 0
        self.ticks = ticks

        # reset state
        self.reset()
Exemplo n.º 9
0
class State(object):
    r"""State class.

    The `State` is returned by the environment and given to the policy. The state might include information
    about the state of one or several objects in the world, including robots.

    It is the main bridge between the robots/objects in the environment and the policy. Specifically, it is given
    as an input to the policy which knows how to feed the state to the learning model. Usually, the user only has to
    instantiate a child of this class, and give it to the policy and environment, and that's it.
    In addition to the policy, the state can be given to a controller, dynamic model, value estimator, reward function,
    and so on.

    To allow our framework to be modular, we favor composition over inheritance [1] leading the state to be decoupled
    from notions such as the environment, policy, rewards, etc. This class also describes the `state_space` which has
    initially been defined in `gym.Env` [2].

    Note that the policy does not represent in a strict sense the robot but more its brain, the sensors and actuators
    are parts of the environments. Note also that any kind of data can be represented with numbers (e.g. binary code).

    Example:

        sim = Bullet()
        robot = Robot(sim)

        # Two ways to initialize states
        states = State([JntPositionState(robot), JntVelocityState(robot)])
        # or
        states = JntPositionState(robot) + JntVelocityState(robot)

        actions = JntPositionAction(robot)

        policy = NNPolicy(states, actions)

    References:
        - [1] "Wikipedia: Composition over Inheritance", https://en.wikipedia.org/wiki/Composition_over_inheritance
        - [2] "OpenAI gym": https://gym.openai.com/   and    https://github.com/openai/gym
    """
    def __init__(self,
                 states=(),
                 data=None,
                 space=None,
                 window_size=1,
                 axis=None,
                 ticks=1,
                 name=None):
        """
        Create a state that either wraps some data, or groups several other states.

        Args:
            states (list/tuple/set/OrderedSet of State, None): states to combine together. It must be empty (or
                None) when `data` is given. As a shortcut, a numpy array passed here while `data` is None is
                interpreted as the data (i.e. `State(array)`).
            data (np.ndarray, list, tuple, int, float, None): data held by this state. Lists, tuples and scalars
                are converted to a numpy array (a scalar becomes a one-element array).
            space (gym.space, None): space associated with the given data.
            window_size (int, None): number of consecutive data to remember. For instance, remembering the current
                state :math:`s_t` and the previous one :math:`s_{t-1}` requires a window size of 2. The default of 1
                only keeps the current data; None or values below 1 are replaced by 1. It is meant for states
                holding :attr:`data`, not for combined states.
            axis (int, None): how the data in the window are assembled. With None (default), they are concatenated:
                data of shape (n,) give a shape of (n*w,) where w is the window size. With an integer, they are
                stacked along that axis: (w,n) for axis=0, and (n,w) for axis=-1 or 1. It is meant for states
                holding :attr:`data`, not for combined states.
            ticks (int): number of ticks to sleep before getting the next state data.
            name (str, None): name of the state. If None, the name of the class is used.

        Raises:
            TypeError: if `states` is not one of the supported containers, or if `data`, `window_size`, `axis` or
                `name` has an unsupported type.
            ValueError: if a non-empty `states` and `data` are both provided.
        """
        # normalize the container of inner states
        states = tuple() if states is None else states

        if not isinstance(states, (list, tuple, set, OrderedSet)):
            # TODO: should check that states is a list of state, however O(N)
            # the only other accepted form is the shortcut `State(data)` with a numpy array
            if data is not None or not isinstance(states, np.ndarray):
                raise TypeError(
                    "Expecting a list, tuple, or (ordered) set of states.")
            data, states = states, tuple()

        # a state is either a combination of states or a data holder, never both
        if data is not None and len(states) > 0:
            raise ValueError(
                "Please specify only one of the argument `states` xor `data`, but not both."
            )

        # convert the data to a numpy array if it is not one already
        if data is not None and not isinstance(data, np.ndarray):
            if isinstance(data, (list, tuple)):
                data = np.array(data)
            elif isinstance(data, (int, float)):
                data = np.array([data])
            else:
                raise TypeError(
                    "Expecting a numpy array, a list/tuple of int/float, or an int/float for 'data'"
                )

        # The following attributes should normally be set in the child classes
        self._data = data
        if data is None:
            self._torch_data = None
        else:
            self._torch_data = torch.from_numpy(data).float()
        self._space = space
        self._distribution = None  # for sampling
        self._normalizer = None
        self._noiser = None  # for noise
        self.name = name
        self._training_mode = False

        # ordered set of inner states; only populated when this state holds no data
        self._states = OrderedSet()
        if self._data is None:
            self.add(states)

        # data windows: assigning `window_size` creates the windows (FIFO queues)
        self._window, self._torch_window = None, None
        self.window_size = window_size
        self.axis = axis

        # tick counter and tick period
        self._cnt = 0
        self.ticks = ticks

        # reset state
        self.reset()

    ##############################
    # Properties (Getter/Setter) #
    ##############################

    @property
    def states(self):
        """
        Get the list of states.
        """
        return self._states

    @states.setter
    def states(self, states):
        """
        Add the given states to the inner states of this (combined) state.

        Each item is checked and added one at a time, so the items located before an invalid one are still added.

        Args:
            states (iterable of State): states to add.

        Raises:
            AttributeError: if the current state already holds data.
            TypeError: if `states` is not iterable, or if one of its items is not a `State`.
        """
        # `collections.Iterable` was removed in Python 3.10; the ABC lives in `collections.abc`
        try:
            from collections.abc import Iterable
        except ImportError:  # Python 2
            from collections import Iterable

        if self.has_data():
            raise AttributeError(
                "Trying to add internal states to the current state while it already has some data. "
                "A state should be a combination of states or should contain some kind of data, "
                "but not both.")
        if not isinstance(states, Iterable):
            raise TypeError(
                "Expecting an iterator (e.g. list, tuple, OrderedSet, set,...) over states"
            )
        for state in states:
            if not isinstance(state, State):
                raise TypeError(
                    "One of the given states is not an instance of State.")
            self.add(state)

    @property
    def data(self):
        """
        Return the data of this state, or the data of each inner state for a combined state.

        For a state holding data, the content of the window is returned as a one-element list: the single item
        if the window has a length of 1, otherwise the items concatenated (`axis` is None) or stacked along
        `axis`.

        Returns:
            list of np.ndarray: list of data associated to the state.
        """
        # combined state: gather the (first) data of each inner state
        if not self.has_data():
            return [state.data[0] for state in self._states]

        window = self.window
        if len(window) == 1:
            return [window[0]]  # [self._data]

        # several items in the window: concatenate them, or stack them along the requested axis
        if self.axis is None:
            return [np.concatenate(window.queue)]
        return [np.stack(window.queue, axis=self.axis)]

    @data.setter
    def data(self, data):
        """
        Set the data of this state, or dispatch one data segment to each inner state.

        For a state holding data, the value is converted to a numpy array, clipped to the associated space (if
        any), stored along with its torch version, and appended to the windows. Windows that are not full yet are
        padded with the last appended value.

        Args:
            data (np.ndarray, list, tuple, int, float): the data to set. For a combined state, an iterable with
                one item per inner state.

        Raises:
            TypeError: if the data is not iterable (combined state), or has an unsupported type (single state).
            ValueError: if the number of items differs from the number of inner states (combined state), or if
                the shape differs from the shape of the previous data (single state).
        """
        # `collections.Iterable` was removed in Python 3.10; the ABC lives in `collections.abc`
        try:
            from collections.abc import Iterable
        except ImportError:  # Python 2
            from collections import Iterable

        if self.has_states():  # combined states
            if not isinstance(data, Iterable):
                raise TypeError("data is not an iterator")
            if len(self._states) != len(data):
                raise ValueError(
                    "The number of states is different from the number of data segments"
                )
            for state, d in zip(self._states, data):
                state.data = d

        # one state: change the data
        else:
            # make sure data is a numpy array
            if not isinstance(data, np.ndarray):
                if isinstance(data, (list, tuple)):
                    data = np.array(data)
                    # unwrap a one-item container when its shape does not match the previous data
                    # (only possible if there is previous data to compare with)
                    if (len(data) == 1 and self._data is not None
                            and self._data.shape != data.shape):  # TODO: check this line
                        data = data[0]
                elif isinstance(data, (int, float)):
                    # broadcast the scalar to the previous shape; without previous data, use a one-element
                    # array as done in the constructor
                    shape = (1,) if self._data is None else self._data.shape
                    data = data * np.ones(shape)
                else:
                    raise TypeError(
                        "Expecting a numpy array, a list/tuple of int/float, or an int/float for 'data'"
                    )

            # if previous data shape is different from the current one
            if self._data is not None and self._data.shape != data.shape:
                raise ValueError(
                    "The given data does not have the same shape as previously."
                )

            # clip the value using the space
            if self.has_space():
                if self.is_continuous():  # continuous case
                    low, high = self._space.low, self._space.high
                    data = np.clip(data, low, high)
                else:  # discrete case
                    n = self._space.n
                    if data.size == 1:
                        # a discrete space with n elements covers 0..n-1, so n itself is out of range
                        data = np.clip(data, 0, n - 1)

            # set data
            self._data = data
            self._torch_data = torch.from_numpy(data).float()
            self.window.append(self._data)
            self.torch_window.append(self._torch_data)

            # check that the window is full: if not, copy the last data
            if len(self.window) != self.window.maxsize:
                for _ in range(self.window.maxsize - len(self.window)):
                    # copy last appended data
                    self.window.append(self.window[-1])
                    self.torch_window.append(self.torch_window[-1])

    @property
    def merged_data(self):
        """
        Return the data of the fused state, i.e. the `data` of what `self.fuse()` returns.
        """
        return self.fuse().data

    @property
    def last_data(self):
        """
        Return the last provided data.

        Returns:
            the last item of the window for a state holding data, otherwise the list of `last_data` of each
            inner state.
        """
        if not self.has_data():
            return [state.last_data for state in self._states]
        return self.window[-1]

    @property
    def torch_data(self):
        """
        Return the data as a list of torch tensors.

        For a state holding data, the content of the torch window is returned as a one-element list: the single
        item if the window has a length of 1, otherwise the items concatenated (`axis` is None) or stacked along
        `axis`. For a combined state, the (first) torch data of each inner state is returned.

        Returns:
            list of torch.Tensor: list of data associated to the state.
        """
        # combined state: gather the (first) torch data of each inner state
        if not self.has_data():
            return [state.torch_data[0] for state in self._states]

        window = self.torch_window
        if len(window) == 1:
            return [window[0]]  # [self._torch_data]

        # several items in the window: concatenate them, or stack them along the requested axis
        if self.axis is None:
            return [torch.cat(window.tolist())]
        return [torch.stack(window.tolist(), dim=self.axis)]

    @torch_data.setter
    def torch_data(self, data):
        """
        Set the torch data and update the numpy version of the data.

        For a combined state, one data segment is dispatched to each inner state. For a state holding data, the
        value is converted to a float tensor, clipped to the associated space (if any), stored along with its
        numpy version, and appended to the windows. Windows that are not full yet are padded with the last
        appended value.

        Args:
            data (torch.Tensor, np.ndarray, list, tuple, int, float, or an iterable of these): data to set. For a
                combined state, an iterable with one item per inner state.

        Raises:
            TypeError: if the data is not iterable (combined state), or has an unsupported type (single state).
            ValueError: if the number of items differs from the number of inner states (combined state), or if
                the shape differs from the shape of the previous data (single state).
        """
        # `collections.Iterable` was removed in Python 3.10; the ABC lives in `collections.abc`
        try:
            from collections.abc import Iterable
        except ImportError:  # Python 2
            from collections import Iterable

        if self.has_states():  # combined states
            if not isinstance(data, Iterable):
                raise TypeError("data is not an iterator")
            if len(self._states) != len(data):
                raise ValueError(
                    "The number of states is different from the number of data segments"
                )
            for state, d in zip(self._states, data):
                state.torch_data = d

        # one state: change the data
        else:
            if isinstance(data, torch.Tensor):
                data = data.float()
            elif isinstance(data, np.ndarray):
                data = torch.from_numpy(data).float()
            elif isinstance(data, (list, tuple)):
                data = torch.from_numpy(np.array(data)).float()
            elif isinstance(data, (int, float)):
                # broadcast the scalar to the previous shape; without previous data, use a one-element
                # tensor as the constructor does for scalars
                shape = (1,) if self._data is None else self._data.shape
                data = data * torch.ones(shape)
            else:
                raise TypeError(
                    "Expecting a Torch tensor, numpy array, a list/tuple of int/float, or an int/float for"
                    " 'data'")

            # the shape can only be compared when there is previous data (same guard as the numpy setter)
            if self._torch_data is not None and self._torch_data.shape != data.shape:
                raise ValueError(
                    "The given data does not have the same shape as previously."
                )

            # clip the value using the space
            if self.has_space():
                if self.is_continuous():  # continuous case
                    low, high = torch.from_numpy(
                        self._space.low), torch.from_numpy(self._space.high)
                    data = torch.min(torch.max(data, low), high)
                else:  # discrete case
                    n = self._space.n
                    # `Tensor.size` is a method, so `data.size == 1` was always False and the clamp never
                    # ran; `numel()` gives the number of elements
                    if data.numel() == 1:
                        # a discrete space with n elements covers 0..n-1, so n itself is out of range
                        data = torch.clamp(data, min=0, max=n - 1)

            # set data
            self._torch_data = data
            self._data = data.detach().numpy(
            ) if data.requires_grad else data.numpy()
            self.torch_window.append(self._torch_data)
            self.window.append(self._data)

            # check that the window is full: if not, copy the last data
            if len(self.window) != self.window.maxsize:
                for _ in range(self.window.maxsize - len(self.window)):
                    # copy last appended data
                    self.window.append(self.window[-1])
                    self.torch_window.append(self.torch_window[-1])

    @property
    def merged_torch_data(self):
        """
        Return the torch data of the fused state, i.e. the `torch_data` of what `self.fuse()` returns.

        Returns:
            list of torch.Tensor: list of data torch tensors.
        """
        return self.fuse().torch_data

    @property
    def last_torch_data(self):
        """
        Return the last provided torch data.

        Returns:
            the last item of the torch window for a state holding data, otherwise the list of
            `last_torch_data` of each inner state.
        """
        if not self.has_data():
            return [state.last_torch_data for state in self._states]
        return self.torch_window[-1]

    @property
    def vec_data(self):
        """
        Return all the merged data flattened and concatenated into a single vector.

        Returns:
            np.array[N]: all the data.
        """
        flattened = [array.reshape(-1) for array in self.merged_data]
        return np.concatenate(flattened)

    @property
    def vec_torch_data(self):
        """
        Return all the merged torch tensors flattened and concatenated into a single tensor.

        Returns:
            torch.Tensor([N]): all the torch tensors reshaped such that they are unidimensional.
        """
        flattened = [tensor.reshape(-1) for tensor in self.merged_torch_data]
        return torch.cat(flattened)

    @property
    def spaces(self):
        """
        Return the space(s) as a list.

        Returns:
            list: a one-element list with the space of this state if it has one, otherwise the space of each
            inner state.
        """
        if not self.has_space():
            return [state._space for state in self._states]
        return [self._space]

    @property
    def space(self):
        """
        Return the space of this state.

        Returns:
            the space of this state if it has one, otherwise a `gym.spaces.Tuple` built from the space of each
            inner state.
        """
        if not self.has_space():
            inner_spaces = [state._space for state in self._states]
            return gym.spaces.Tuple(inner_spaces)
        return self._space

    @space.setter
    def space(self, space):
        """
        Set the space of this state. This can only be done once.

        The assignment is silently ignored unless the state holds data, has no space yet, and the given space is
        a `gym.spaces.Box` or a `gym.spaces.Discrete`.

        Args:
            space (gym.spaces.Box, gym.spaces.Discrete): space to associate with the data.
        """
        if not self.has_data():
            return
        if self.has_space():  # already set: keep the existing space
            return
        if isinstance(space, (gym.spaces.Box, gym.spaces.Discrete)):
            self._space = space

    @property
    def merged_space(self):
        """
        Return the merged space. Note that all the spaces have to be of the same type to be merged.

        Returns:
            the space of this state if it has one. Otherwise, if all the inner spaces are boxes, a single
            `gym.spaces.Box` with the concatenated bounds; if they are all discrete, a single
            `gym.spaces.Discrete` whose size is the sum of the sizes; and in any other case (mixed types, or no
            inner space), the value of `self.space`.

        Raises:
            NotImplementedError: if an inner space is neither a `gym.spaces.Box` nor a `gym.spaces.Discrete`.
        """
        if self.has_space():
            return self._space

        lows, highs, sizes = [], [], []
        kind = None  # type shared by the spaces seen so far
        for space in self.spaces:
            if isinstance(space, gym.spaces.Box):
                current = 'box'
                lows.append(space.low)
                highs.append(space.high)
            elif isinstance(space, gym.spaces.Discrete):
                current = 'discrete'
                sizes.append(space.n)
            else:
                raise NotImplementedError

            # spaces of different types can not be merged
            if kind is not None and current != kind:
                return self.space
            kind = current

        if kind == 'box':
            low = np.concatenate(lows)
            high = np.concatenate(highs)
            return gym.spaces.Box(low=low, high=high, dtype=np.float32)
        if kind == 'discrete':
            return gym.spaces.Discrete(n=np.sum(sizes))

        return self.space

    @property
    def name(self):
        """
        Return the name of the state.
        """
        if self._name is None:
            return self.__class__.__name__
        return self._name

    @name.setter
    def name(self, name):
        """
        Set the name of the state.
        """
        if name is None:
            name = self.__class__.__name__
        if not isinstance(name, str):
            raise TypeError("Expecting the name to be a string.")
        self._name = name

    @property
    def shape(self):
        """
        Return the shape of each item in `self.data`. Some states, such as camera states, have more than 1
        dimension.
        """
        return [array.shape for array in self.data]

    @property
    def merged_shape(self):
        """
        Return the shape of each item in `self.merged_data`.
        """
        return [array.shape for array in self.merged_data]

    @property
    def size(self):
        """
        Return the size (`.size` attribute) of each item in `self.data`.
        """
        return [array.size for array in self.data]

    @property
    def merged_size(self):
        """
        Return the size (`.size` attribute) of each item in `self.merged_data`.
        """
        return [array.size for array in self.merged_data]

    @property
    def dimension(self):
        """
        Return the dimension (i.e. the length of the shape) of each item in `self.data`.
        """
        return [len(array.shape) for array in self.data]

    @property
    def merged_dimension(self):
        """
        Return the dimension (i.e. the length of the shape) of each item in `self.merged_data`.
        """
        return [len(array.shape) for array in self.merged_data]

    @property
    def num_dimensions(self):
        """
        Return the number of distinct dimensions (lengths of shape) found in `self.dimension`.
        """
        distinct_dimensions = np.unique(self.dimension)
        return len(distinct_dimensions)

    # @property
    # def distribution(self):
    #     """
    #     Get the current distribution used when sampling the state
    #     """
    #     pass
    #
    # @distribution.setter
    # def distribution(self, distribution):
    #     """
    #     Set the distribution to the state.
    #     """
    #     # check if distribution is discrete/continuous
    #     pass

    @property
    def in_training_mode(self):
        """Return True if we are in training mode."""
        return self._training_mode

    @property
    def window(self):
        """Return the window."""
        return self._window

    @property
    def torch_window(self):
        """Return the torch window."""
        return self._torch_window

    @property
    def window_size(self):
        """Return the window size, i.e. the `maxsize` of the current window."""
        return self.window.maxsize

    @window_size.setter
    def window_size(self, size):
        """
        Set the window size and rebuild both windows (the data one and the torch one).

        If the state currently has some data, the new windows are completely filled with it.

        Args:
            size (int, None): new window size. None and values lower than 1 are replaced by 1.

        Raises:
            TypeError: if the given size is neither None nor an int.
        """
        if size is None:
            size = 1
        elif not isinstance(size, int):
            raise TypeError(
                "Expecting the given window size to be an int, instead got: {}"
                .format(type(size)))
        size = max(size, 1)

        # create the two (empty) windows
        data_queue = FIFOQueue(maxsize=size)
        torch_queue = FIFOQueue(maxsize=size)
        self._window, self._torch_window = data_queue, torch_queue

        # nothing to put in the windows
        if self._data is None:
            return

        self._window.append(self._data)
        self._torch_window.append(self._torch_data)

        # fill the remaining slots by repeating the last appended data
        missing = self._window.maxsize - len(self._window)
        for _ in range(missing):
            self._window.append(self._window[-1])
            self._torch_window.append(self._torch_window[-1])

    @property
    def axis(self):
        """Return the axis used to concatenate or stack the states in the current window."""
        return self._axis

    @axis.setter
    def axis(self, axis):
        """
        Set the axis used to concatenate or stack the states in the current window.

        Args:
            axis (int, None): None (concatenate) or an int (stack).

        Raises:
            TypeError: if the given axis is neither None nor an int.
        """
        is_valid = axis is None or isinstance(axis, int)
        if not is_valid:
            raise TypeError(
                "Expecting the given axis to be None (concatenate) or an int (stack), instead got: "
                "{}".format(type(axis)))
        self._axis = axis

    @property
    def ticks(self):
        """Return the number of ticks to sleep before getting the next state data."""
        return self._ticks

    @ticks.setter
    def ticks(self, ticks):
        """
        Set the number of ticks to sleep before getting the next state data.

        Args:
            ticks (int): number of ticks. It is converted with `int(...)`, and any value lower than 1 becomes 1.
        """
        self._ticks = max(int(ticks), 1)

    ###########
    # Methods #
    ###########

    def train(self):
        """Set the state in training mode (the `_training_mode` flag becomes True)."""
        self._training_mode = True

    def eval(self):
        """Set the state in evaluation / test mode (the `_training_mode` flag becomes False)."""
        self._training_mode = False

    def is_combined_states(self):
        """
        Tell whether this state is a combination of other states.

        Returns:
            bool: True if the internal collection of states is not empty, False otherwise.
        """
        num_states = len(self._states)
        return num_states > 0

    # alias
    has_states = is_combined_states

    def has_data(self):
        """
        Check if the state has data.

        Returns:
            bool: True if the `_data` attribute is not None, False otherwise.
        """
        return self._data is not None
        # return len(self._states) == 0

    def has_space(self):
        """
        Check if the state has a space.

        Returns:
            bool: True if the `_space` attribute is not None, False otherwise.
        """
        return self._space is not None

    def add(self, state):
        """
        Add a state or a list of states to the list of internal states. Useful when combining different states together.
        This shouldn't be called if this state has some data set to it.

        Args:
            state (State, list/tuple of State): state(s) to add to the internal list of states
        """
        if self.has_data():
            raise AttributeError(
                "Undefined behavior: a state should be a combination of states or should contain "
                "some kind of data, but not both.")
        if isinstance(state, State):
            self._states.add(state)
        elif isinstance(state, collections.Iterable):
            for i, s in enumerate(state):
                if not isinstance(s, State):
                    raise TypeError(
                        "The item {} in the given list is not an instance of State"
                        .format(i))
                self._states.add(s)
        else:
            raise TypeError(
                "The 'other' argument should be an instance of State, or an iterator over states."
            )

    # alias
    append = add
    extend = add

    def _read(self):
        """
        Read the state value. This has to be overwritten in the child class.

        The base implementation does nothing. It is called by `read()` and `_reset()`.
        """
        pass

    def read(self, return_data=True, merged_data=False):
        """
        Read the value of each state (through their `_read` method) and optionally return the data.

        The actual reading only happens when the internal counter is a multiple of `ticks`; the counter is
        incremented at every call.

        Args:
            return_data (bool): if True, it will return the data.
            merged_data (bool): if True (and `return_data` is True), it will return `merged_data` instead of `data`.

        Returns:
            the data (or merged data), or None if `return_data` is False.
        """
        # read only every `ticks` calls
        time_to_read = (self._cnt % self.ticks == 0)
        if time_to_read:
            # read each inner state if combined, otherwise read this state
            targets = self.states if self.has_states() else [self]
            for target in targets:
                target._read()

        self._cnt += 1

        if not return_data:
            return None
        return self.merged_data if merged_data else self.data

    def _reset(self):
        """
        Reset the state. This has to be overwritten in the child class.

        The base implementation sets the internal counter back to 0 and then reads the state with `_read()`.
        """
        self._cnt = 0
        self._read()

    def reset(self, return_data=True, merged_data=False):
        """
        Reset the state (some states need to be reset) and optionally return the data.

        The internal counter is set back to 0, then `_reset` is called on each inner state if this is a combined
        state, or on this state otherwise.

        Args:
            return_data (bool): if True, it will return the data.
            merged_data (bool): if True (and `return_data` is True), it will return `merged_data` instead of `data`.

        Returns:
            the data (or merged data), or None if `return_data` is False.
        """
        self._cnt = 0

        # reset each inner state if combined, otherwise reset this state
        targets = self.states if self.has_states() else [self]
        for target in targets:
            target._reset()

        if not return_data:
            return None
        return self.merged_data if merged_data else self.data

    def max_dimension(self):
        """
        Return the maximum dimension.

        Returns:
            int: the biggest value in `self.dimension`.
        """
        dims = self.dimension
        return max(dims)

    def total_size(self):
        """
        Return the total size of the combined state.

        Returns:
            int: the sum of the values in `self.size`.
        """
        sizes = self.size
        return sum(sizes)

    def has_discrete_values(self):
        """
        Tell, for each state, whether its space is discrete.

        Returns:
            list of bool: if this state has no data, one boolean per inner state (True if its space is a
                `gym.spaces.Discrete`). Otherwise, a list with a single boolean for this state.
        """
        if self._data is None:
            return [
                isinstance(inner._space, gym.spaces.Discrete)
                for inner in self._states
            ]
        return [isinstance(self._space, gym.spaces.Discrete)]

    def is_discrete(self):
        """
        Tell whether the state is discrete: it is the case if all the states are discrete.

        Returns:
            bool: True if `has_discrete_values()` is not empty and only contains True values.
        """
        flags = self.has_discrete_values()
        return len(flags) > 0 and all(flags)

    def has_continuous_values(self):
        """
        Tell, for each state, whether its space is continuous.

        Returns:
            list of bool: if this state has no data, one boolean per inner state (True if its space is a
                `gym.spaces.Box`). Otherwise, a list with a single boolean for this state.
        """
        if self._data is None:
            return [
                isinstance(inner._space, gym.spaces.Box)
                for inner in self._states
            ]
        return [isinstance(self._space, gym.spaces.Box)]

    def is_continuous(self):
        """
        Tell whether the state is continuous: it is the case as soon as one of the states is continuous.

        Returns:
            bool: True if `has_continuous_values()` contains at least one True value.
        """
        flags = self.has_continuous_values()
        return any(flags)

    def bounds(self):
        """
        Return the bounds of the state(s).

        For a `gym.spaces.Box` space, these are the lower and higher bounds. For a `gym.spaces.Discrete` space, this
        is the number `n` of discrete values.

        Returns:
            list/tuple: list of the bounds of each inner state if this state has no data, otherwise the tuple
                `(low, high)` (continuous case) or `(n,)` (discrete case) for this state.

        Raises:
            NotImplementedError: if this state has data and its space is neither a Box nor a Discrete space.
        """
        if self._data is None:
            return [inner.bounds() for inner in self._states]

        space = self._space
        if isinstance(space, gym.spaces.Box):
            return space.low, space.high
        if isinstance(space, gym.spaces.Discrete):
            return (space.n, )
        raise NotImplementedError

    def apply(self, fct):
        """
        Apply the given function on the data of the state, and set the result as the new data.

        Args:
            fct (callable): function that takes `self.data` as input and returns the new data.
        """
        transformed = fct(self.data)
        self.data = transformed

    def contains(self, x):  # parameter dependent of the state
        """
        Check if the argument is within the range/bound of the state.

        Args:
            x: value forwarded to the `contains` method of the state space.

        Returns:
            the result of `self._space.contains(x)`.
        """
        # NOTE(review): no guard against `self._space` being None here -- confirm callers check `has_space()`
        return self._space.contains(x)

    def sample(self, distribution=None):  # parameter dependent of the state (discrete and continuous distributions)
        """
        Sample some values from the state.

        Args:
            distribution: currently unused by this implementation; only the internal `_distribution` attribute is
                checked.

        Returns:
            list: if this is a combined state, the sample of each inner state.
            otherwise, the value sampled from the state space when no internal distribution is set.

        Raises:
            NotImplementedError: if an internal distribution is set (sampling from it is not implemented).
        """
        if self.is_combined_states():
            return [inner.sample() for inner in self._states]
        if self._distribution is not None:
            raise NotImplementedError
        # no distribution: sample from the space
        return self._space.sample()

    def add_noise(self, noise=None, replace=True):  # parameter dependent of the state
        """
        Add some noise to the state data, and return the noisy data.

        Args:
            noise (np.ndarray, float, callable, None): array/number to be added to the data, or function to be applied
                on the data (it takes the data as input and returns the noisy data). If None, the data is left
                untouched.
            replace (bool): if True, it will replace the `data` attribute by the noisy data. If False, the noisy data
                is only returned.

        Returns:
            list: if this state has no data (combined state), the noisy data of each inner state.
            otherwise, the noisy data of this state. When `replace` is True, this is the data read back from the
                state after the assignment.
        """
        # combined state: apply the noise on each inner state
        if self._data is None:
            return [
                state.add_noise(noise=noise, replace=replace)
                for state in self._states
            ]

        current = self.data[0]

        # no noise: nothing to do
        if noise is None:
            return current

        # the noise is either a function to apply on the data, or a quantity to add to it
        noisy_data = noise(current) if callable(noise) else current + noise

        if replace:
            self.data = noisy_data
            return self.data[0]
        return noisy_data

    def normalize(self,
                  normalizer=None,
                  replace=True):  # parameter dependent of the state
        """
        Normalize the state data using the provided normalizer.

        Warning:
            This base implementation is a placeholder: it does nothing with its arguments and returns None.

        Args:
            normalizer (sklearn.preprocessing.Normalizer): the normalizer to apply to the data (currently unused).
            replace (bool): if True, it is meant to replace the `data` attribute by the normalized data (currently
                unused).

        Returns:
            None
        """
        pass

    def fuse(self, other=None, axis=0):
        """
        Fuse the states whose data have the same number of dimensions together, by concatenating their data. The
        given axis specifies along which axis the data are concatenated. If this axis does not exist for a group of
        states, the last axis of that group is used instead.

        Neither this state nor the `other` state is modified.

        Args:
            other (State, None): other (combined) state to fuse with this one.
            axis (int): axis along which the data are concatenated.

        Returns:
            State: `self` if there are fewer than 2 states to fuse, the single fused state if all the states could
                be fused together, or a combined state containing one state per number of dimensions.

        Raises:
            TypeError: if `other` is neither None nor a State.

        Examples:
            s0 = JointPositionState(robot)
            s1 = JointVelocityState(robot)
            s = s0 & s1
            print(s)
            print(s.shape)
            s = s0 + s1
            s.fuse()
            print(s)
            print(s.shape)
        """
        # check argument
        if not (other is None or isinstance(other, State)):
            raise TypeError(
                "The 'other' argument should be None or another state.")

        # build the list of all the states. We work on a copy such that the internal states of `self` are not
        # modified when the states of `other` are added to the list.
        states = [self] if self.has_data() else list(self._states)
        if other is not None:
            extra = [other] if other.has_data() else list(other._states)
            for state in extra:
                # do not fuse the same state twice
                if not any(state is known for state_idx, known in enumerate(states)):
                    states.append(state)

        # check if only one state
        if len(states) < 2:
            return self  # do nothing

        # group the states by number of dimensions (ordered to get a deterministic result)
        groups = collections.OrderedDict()
        for state in states:
            groups.setdefault(len(state.data[0].shape), []).append(state)

        # traverse the groups and fuse the states in each of them
        fused_states = []
        for ndim, members in groups.items():
            if len(members) > 1:
                arrays = [state.data[0] for state in members]
                names = [state.name for state in members]
                # valid axes for an array with `ndim` dimensions are lower than `ndim`
                concat_axis = min(axis, max(ndim - 1, 0))
                fused = State(data=np.concatenate(arrays, axis=concat_axis),
                              name='+'.join(names))
                fused_states.append(fused)
            else:
                # only one state
                fused_states.append(members[0])

        # return the fused state
        if len(fused_states) == 1:
            return fused_states[0]
        return State(fused_states)

    def lookfor(self, class_type):
        """
        Look for the specified class type/name in this state and in the list of internal states, and return the first
        state that matches.

        A state matches if its class is the given class type, or, when a string is given, if its class name or its
        `name` attribute is equal to that string (the comparison is case-insensitive).

        Args:
            class_type (type, str): class type or name

        Returns:
            State, None: the corresponding instance of the State class. None if it was not found.
        """
        # if string, lowercase it such that the comparison with the names is case-insensitive
        query_is_name = isinstance(class_type, str)
        if query_is_name:
            class_type = class_type.lower()

        def matches(candidate):
            """Check if the given state corresponds to the requested class type / name."""
            cls = candidate.__class__
            if cls == class_type or cls.__name__.lower() == class_type:
                return True
            name = candidate.name
            # the query has been lowercased, so the name has to be lowercased as well
            if query_is_name and isinstance(name, str):
                return name.lower() == class_type
            return name == class_type

        # if there is one state
        if self.has_data() and matches(self):
            return self

        # the state has multiple states, thus we go through each state
        for state in self.states:
            if matches(state):
                return state
        return None

    #############
    # Operators #
    #############

    def __str__(self):
        """Return a representation string about the object."""
        if self._data is None:
            lst = [self.__class__.__name__ + '(']
            for state in self.states:
                lst.append('\t' + state.__str__() + ',')
            lst.append(')')
            return '\n'.join(lst)
        else:
            return '%s(%s)' % (self.name, self.data[0])

    # def __str__(self):
    #     """
    #     String to represent the state. Need to be provided by each child class.
    #     """
    #     if self._data is None:
    #         return [str(state) for state in self._states]
    #     return str(self)

    def __call__(self, return_data=True, merged_data=False):
        """
        Compute/read the state and return it. It is an alias to the `self.read()` method.

        Args:
            return_data (bool): forwarded to `read()`.
            merged_data (bool): forwarded to `read()`.

        Returns:
            whatever `self.read(return_data=return_data, merged_data=merged_data)` returns.
        """
        return self.read(return_data=return_data, merged_data=merged_data)

    def __len__(self):
        """
        Return the total number of states contained in this class.

        Example::

            s1 = JntPositionState(robot)
            s2 = s1 + JntVelocityState(robot)
            print(len(s1)) # returns 1
            print(len(s2)) # returns 2
        """
        if self._data is None:
            return len(self._states)
        return 1

    def __iter__(self):
        """
        Iterator over the states.
        """
        if self.is_combined_states():
            for state in self._states:
                yield state
        else:
            yield self

    def __contains__(self, item):
        """
        Check if the state item(s) is(are) in the combined state. If the item is the data associated with the state,
        it checks that it is within the bounds.

        Args:
            item (State, list/tuple of state, type): check if given state(s) is(are) in the combined state

        Example:
            s1 = JointPositionState(robot)
            s2 = JointVelocityState(robot)
            s = s1 + s2
            print(s1 in s) # output True
            print(s2 in s1) # output False
            print((s1, s2) in s) # output True
        """
        # check type of item
        if not isinstance(item, (State, np.ndarray, type)):
            raise TypeError(
                "Expecting a State, np.array, or a class type, instead got: {}"
                .format(type(item)))

        # if class type
        if isinstance(item, type):
            # if there is one state
            if self.has_data():
                return self.__class__ == item
            # the state has multiple states, thus we go through each state
            for state in self.states:
                if state.__class__ == item:
                    return True
            return False

        # check if state item is in the combined state
        if self._data is None and isinstance(item, State):
            return item in self._states

        # check if state/data is within the bounds
        if isinstance(item, State):
            item = item.data

        # check if continuous
        # if self.is_continuous():
        #     low, high = self.bounds()
        #     return np.all(low <= item) and np.all(item <= high)
        # else: # discrete case
        #     num = self.bounds()[0]
        #     # check the size of data
        #     if item.size > 1: # array
        #         return (item.size < num)
        #     else: # one number
        #         return (item[0] < num)

        return self.contains(item)

    def __getitem__(self, key):
        """
        Get the corresponding item from the state(s)
        """
        # if one state, slice the corresponding state data
        if len(self._states) == 0:
            return self.data[0][key]
        # if multiple states
        if isinstance(key, int):
            # get one state
            return self._states[key]
        elif isinstance(key, slice):
            # get multiple states
            return State(self._states[key])
        else:
            raise TypeError(
                "Expecting an int or slice for the key, but got instead {}".
                format(type(key)))

    def __setitem__(self, key, value):
        """
        Set the corresponding item/value to the corresponding key.

        Args:
            key (int, slice): index of the internal state, or index/indices for the state data
            value (State, int/float, array): value to be set
        """
        if self.is_combined_states():
            # set/move the state to the specified key
            if isinstance(value, State) and isinstance(key, int):
                self._states[key] = value
            else:
                raise TypeError(
                    "Expecting key to be an int, and value to be a state.")
        else:
            # set the value on the data directly
            self.data[0][key] = value

    def __add__(self, other):
        """
        Combine two different states together. In this special case, the operation is not commutable.
        This is the same as taking the union of the states.

        Args:
            other (State): another state

        Returns:
            State: the combined state

        Examples:
            s1 = JntPositionState(robot)
            s2 = JntVelocityState(robot)
            s = s1 + s2     # = State([JntPositionState(robot), JntVelocityState(robot)])

            s1 = State([JntPositionState(robot), JntVelocityState(robot)])
            s2 = State([JntPositionState(robot), LinkPositionState(robot)])
            s = s1 + s2     # = State([JntPositionState(robot), JntVelocityState(robot), LinkPositionState(robot)])
        """
        if not isinstance(other, State):
            raise TypeError("Expecting another state, instead got {}".format(
                type(other)))
        s1 = self._states if self._data is None else OrderedSet([self])
        s2 = other._states if other._data is None else OrderedSet([other])
        s = s1 + s2
        return State(s)

    def __iadd__(self, other):
        """
        Add a state to the current one.

        Args:
            other (State, list/tuple of State): other state

        Examples:
            s = State()
            s += JntPositionState(robot)
            s += JntVelocityState(robot)
        """
        if self._data is not None:
            raise AttributeError(
                "The current class already has some data attached to it. This operation can not be "
                "applied in this case.")
        self.append(other)

    def __sub__(self, other):
        """
        Remove the other state(s) from the current state.

        Args:
            other (State): state to be removed.
        """
        if not isinstance(other, State):
            raise TypeError("Expecting another state, instead got {}".format(
                type(other)))
        s1 = self._states if self._data is None else OrderedSet([self])
        s2 = other._states if other._data is None else OrderedSet([other])
        s = s1 - s2
        if len(s) == 1:  # just one element
            return s[0]
        return State(s)

    def __isub__(self, other):
        """
        Remove one or several states from the combined state.

        Args:
            other (State): state to be removed.
        """
        if not isinstance(other, State):
            raise TypeError("Expecting another state, instead got {}".format(
                type(other)))
        if self._data is not None:
            raise RuntimeError(
                "This operation is only available for a combined state")
        s = other._states if other._data is None else OrderedSet([other])
        self._states -= s

    def __and__(self, other):
        """
        Fuse two states together; only one data for the two states, instead of a data for each state as done
        when combining the states. The data are fused along the axis=0. This simply calls
        `self.fuse(other, axis=0)`.

        Args:
            other: the other (combined) state

        Returns:
            State: the fused state, as returned by `fuse`

        Examples:
            s0 = JntPositionState(robot)
            s1 = JntVelocityState(robot)
            print(s0.shape)
            print(s1.shape)
            s = s0 + s1
            print(s.shape)  # prints [s0.shape, s1.shape]
            s = s0 & s1
            print(s.shape)  # prints np.concatenate((s0,s1)).shape
        """
        return self.fuse(other, axis=0)

    def __copy__(self):
        """Return a shallow copy of the state. This can be overridden in the child class."""
        return self.__class__(states=self.states,
                              data=self._data,
                              space=self._space,
                              name=self.name,
                              window_size=self.window_size,
                              axis=self.axis,
                              ticks=self.ticks)

    def __deepcopy__(self, memo={}):
        """Return a deep copy of the state. This can be overridden in the child class.

        Args:
            memo (dict): memo dictionary of objects already copied during the current copying pass
        """
        if self in memo:
            return memo[self]

        states = [copy.deepcopy(state, memo) for state in self.states]
        data = copy.deepcopy(self.window[0]) if self.has_data() else None
        space = copy.deepcopy(self._space)
        state = self.__class__(states=states,
                               data=data,
                               space=space,
                               name=self.name,
                               window_size=self.window_size,
                               axis=self.axis,
                               ticks=self.ticks)

        memo[self] = state
        return state
Exemplo n.º 10
0
class Action(object):
    r"""Action class.

    The `Action` is produced by the policy in response to a certain state/observation. From a programming point of
    view, compared to the `State` class, the action is a setter object. Thus, they have a very close relationship
    and share many functionalities. Some actions are mutually exclusive and cannot be executed at the same time.

    An action is defined as something that affects the environment; that forces the environment to go to the next
    state. For instance, an action could be the desired joint positions, but also an abstract action such as
    'open a door' which would then open a door in the simulator and load the next part of the world.

    In our framework, the `Action` class is decoupled from the policy and environment rendering it more modular [1].
    Nevertheless, the `Action` class still acts as a bridge between the policy and environment. In addition to be
    the output of a policy/controller, it can also be the input to some value approximators, dynamic models, reward
    functions, and so on.

    This class also describes the `action_space` which has initially been defined in `gym.Env` [2].

    References:
        [1] "Wikipedia: Composition over Inheritance", https://en.wikipedia.org/wiki/Composition_over_inheritance
        [2] "OpenAI gym": https://gym.openai.com/   and    https://github.com/openai/gym
    """
    def __init__(self, actions=(), data=None, space=None, name=None, ticks=1):
        """
        Initialize the action. The action contains some kind of data, or is a combination of other actions.

        Args:
            actions (list/tuple of Action): list of actions to be combined together (if given, we can not specify
                                            data)
            data (np.ndarray): data associated to this action. A list/tuple is converted to a numpy array, and an
                               int/float is converted to a numpy array with one element.
            space (gym.space): space associated with the given data
            name (str): name of the action
            ticks (int): number of ticks to sleep before setting the next action data.

        Raises:
            TypeError: if the actions are not given as a list, tuple, or (ordered) set, or if the data has an
                unsupported type.
            ValueError: if both some actions and some data are given.

        Warning:
            Both arguments can not be provided to the action.
        """
        # check the container of actions
        if actions is None:
            actions = tuple()
        if not isinstance(actions, (list, tuple, set, OrderedSet)):
            raise TypeError(
                "Expecting a list, tuple, or (ordered) set of actions.")
        if data is not None and len(actions) > 0:
            raise ValueError(
                "Please specify only one of the argument `actions` xor `data`, but not both."
            )

        # convert the data (if given) to a numpy array
        if data is not None and not isinstance(data, np.ndarray):
            if isinstance(data, (list, tuple)):
                data = np.array(data)
            elif isinstance(data, (int, float)):
                data = np.array([data])
            else:
                raise TypeError(
                    "Expecting a numpy array, a list/tuple of int/float, or an int/float for 'data'"
                )

        # The following attributes should normally be set in the child classes
        self._data = data
        if data is None:
            self._torch_data = None
        else:
            self._torch_data = torch.from_numpy(data).float()
        self._space = space
        self._distribution = None  # for sampling
        self._normalizer = None
        self._noiser = None  # for noise
        self.name = name

        # ordered set of inner actions, which is useful if this action is a combination of multiple actions
        self._actions = OrderedSet()
        if self._data is None:
            self.add(actions)

        # set ticks and counter
        self.cnt = 0
        self.ticks = int(ticks)

        # reset action
        # self.reset()

    ##############################
    # Properties (Getter/Setter) #
    ##############################

    @property
    def actions(self):
        """
        Get the list of actions.

        Returns:
            the internal (ordered) collection of actions.
        """
        return self._actions

    @actions.setter
    def actions(self, actions):
        """
        Add each of the given actions to the internal collection of actions (through `self.add`).

        Args:
            actions (iterable of Action): actions to add.

        Raises:
            AttributeError: if this action already has some data.
            TypeError: if the argument is not an iterable, or if one of its items is not an Action.
        """
        # `collections.Iterable` has been removed in Python 3.10; the ABC lives in `collections.abc` since 3.3
        try:
            from collections.abc import Iterable
        except ImportError:  # Python 2
            from collections import Iterable

        if self.has_data():
            raise AttributeError(
                "Trying to add internal actions to the current action while it already has some data. "
                "A action should be a combination of actions or should contain some kind of data, "
                "but not both.")
        if not isinstance(actions, Iterable):
            raise TypeError(
                "Expecting an iterator (e.g. list, tuple, OrderedSet, set,...) over actions"
            )
        for action in actions:
            if not isinstance(action, Action):
                raise TypeError(
                    "One of the given actions is not an instance of Action."
                )
            self.add(action)

    @property
    def data(self):
        """
        Get the data associated to this particular action, or the combined data associated to each action.

        Returns:
            list of np.ndarray: list of data associated to the action
        """
        if self.has_data():
            return [self._data]
        return [action._data for action in self._actions]

    @data.setter
    def data(self, data):
        """
        Set the data associated to this particular action, or the combined data associated to each action.
        Each data will be clipped if outside the range/bounds of the corresponding action.

        Args:
            data: the data to set. For a combined action, an iterable with one data item per inner action. For a
                discrete action, an int/float, or a numpy array (if its last dimension is not 1, it is considered
                to be logits and the argmax is taken). Otherwise, a numpy array, a list/tuple, or an int/float.

        Raises:
            TypeError: if the data has an unsupported type.
            ValueError: if the number of data items differs from the number of inner actions, or if the shape of
                the data differs from the shape of the previous data.
        """
        # `collections.Iterable` has been removed in Python 3.10; the ABC lives in `collections.abc` since 3.3
        try:
            from collections.abc import Iterable
        except ImportError:  # Python 2
            from collections import Iterable

        if self.has_actions():  # combined actions
            if not isinstance(data, Iterable):
                raise TypeError("data is not an iterator")
            if len(self._actions) != len(data):
                raise ValueError(
                    "The number of actions is different from the number of data segments"
                )
            for action, d in zip(self._actions, data):
                action.data = d
            return

        # one action: change the data
        if self.is_discrete():  # discrete action
            if isinstance(data, np.ndarray):  # data action is a numpy array
                # check if given logits or not
                if data.shape[-1] != 1:  # logits
                    data = np.array([np.argmax(data)])
            elif isinstance(data, (int, float, np.integer)):
                # plain Python ints have to be accepted as well: they are not instances of `np.integer`
                data = int(data)
            else:
                raise TypeError(
                    "Expecting the `data` action to be an int, numpy array, instead got: "
                    "{}".format(type(data)))

        if not isinstance(data, np.ndarray):
            if isinstance(data, (list, tuple)):
                data = np.array(data)
                if len(data) == 1 and self._data.shape != data.shape:  # TODO: check this line
                    data = data[0]
            elif isinstance(data, (int, float, np.integer)):  # np.integer is for Py3.5
                data = data * np.ones(self._data.shape)
            else:
                raise TypeError(
                    "Expecting a numpy array, a list/tuple of int/float, or an int/float for 'data'"
                )

        if self._data is not None and self._data.shape != data.shape:
            raise ValueError(
                "The given data does not have the same shape as previously."
            )

        # clip the value using the space
        if self.has_space():
            if self.is_continuous():  # continuous case
                low, high = self._space.low, self._space.high
                data = np.clip(data, low, high)
            else:  # discrete case
                n = self._space.n
                if data.size == 1:
                    # a discrete space with n values contains the values {0, ..., n-1}
                    data = np.clip(data, 0, n - 1)
        self._data = data
        self._torch_data = torch.from_numpy(data).float()

    @property
    def merged_data(self):
        """
        Return the merged data.

        Returns:
            the `data` attribute of the action returned by `self.fuse()`.
        """
        return self.fuse().data

    @property
    def torch_data(self):
        """
        Return the data as a list of torch tensors.
        """
        if self.has_data():
            return [self._torch_data]
        return [action._torch_data for action in self._actions]

    @torch_data.setter
    def torch_data(self, data):
        """
        Set the torch data and update the numpy version of the data. The data is clipped if it is outside the
        range/bounds of the action space.

        Args:
            data (torch.Tensor, list of torch.Tensors): data to set. For a combined action, an iterable with one
                data item per inner action. Otherwise, a torch tensor, a numpy array, a list/tuple, or an int/float.

        Raises:
            TypeError: if the data has an unsupported type.
            ValueError: if the number of data items differs from the number of inner actions, or if the shape of
                the data differs from the shape of the previous data.
        """
        # `collections.Iterable` has been removed in Python 3.10; the ABC lives in `collections.abc` since 3.3
        try:
            from collections.abc import Iterable
        except ImportError:  # Python 2
            from collections import Iterable

        if self.has_actions():  # combined actions
            if not isinstance(data, Iterable):
                raise TypeError("data is not an iterator")
            if len(self._actions) != len(data):
                raise ValueError(
                    "The number of actions is different from the number of data segments"
                )
            for action, d in zip(self._actions, data):
                action.torch_data = d
            return

        # one action: change the data
        if isinstance(data, torch.Tensor):
            data = data.float()
        elif isinstance(data, np.ndarray):
            data = torch.from_numpy(data).float()
        elif isinstance(data, (list, tuple)):
            data = torch.from_numpy(np.array(data)).float()
        elif isinstance(data, (int, float)):
            data = data * torch.ones(self._data.shape)
        else:
            raise TypeError(
                "Expecting a Torch tensor, numpy array, a list/tuple of int/float, or an int/float for"
                " 'data'")

        if self._torch_data.shape != data.shape:
            raise ValueError(
                "The given data does not have the same shape as previously."
            )

        # clip the value using the space
        if self.has_space():
            if self.is_continuous():  # continuous case
                # cast the bounds to float tensors such that they have the same dtype as the data
                low = torch.from_numpy(self._space.low).float()
                high = torch.from_numpy(self._space.high).float()
                data = torch.min(torch.max(data, low), high)
            else:  # discrete case
                n = self._space.n
                # `Tensor.size` is a method (not an int as for numpy arrays): use `numel()` to count the elements
                if data.numel() == 1:
                    # a discrete space with n values contains the values {0, ..., n-1}
                    data = torch.clamp(data, min=0, max=n - 1)
        self._torch_data = data
        if data.requires_grad:
            data = data.detach().numpy()
        else:
            data = data.numpy()
        self._data = data

    @property
    def merged_torch_data(self):
        """
        Return the merged torch data.

        Returns:
            list of torch.Tensor: list of data torch tensors.
        """
        # fuse the data
        fused_action = self.fuse()
        # return the data
        return fused_action.torch_data

    @property
    def vec_data(self):
        """
        Merged data flattened into a single one-dimensional array.

        Returns:
            np.array[N]: concatenation of every merged data array after reshaping each of them to 1D.
        """
        flattened = [chunk.reshape(-1) for chunk in self.merged_data]
        return np.concatenate(flattened)

    @property
    def vec_torch_data(self):
        """
        Return a vectorized form of all the torch tensors.

        Returns:
            torch.Tensor([N]): all the torch tensors reshaped such that they are unidimensional.
        """
        return torch.cat([data.reshape(-1) for data in self.merged_torch_data])

    @property
    def spaces(self):
        """
        Get the corresponding spaces as a list of spaces.
        """
        if self.has_space():
            return [self._space]
        return [action._space for action in self._actions]

    @property
    def space(self):
        """
        Get the corresponding space.
        """
        if self.has_space():
            # return gym.spaces.Tuple([self._space])
            return self._space
        # return [action._space for action in self._actions]
        return gym.spaces.Tuple([action._space for action in self._actions])

    @space.setter
    def space(self, space):
        """
        Set the corresponding space. This can only be used one time!
        """
        if self.has_data() and not self.has_space() and \
                isinstance(space, (gym.spaces.Box, gym.spaces.Discrete, gym.spaces.MultiDiscrete)):
            self._space = space

    @property
    def merged_space(self):
        """
        Get the corresponding merged space. Note that all the spaces have to be of the same type.
        """
        if self.has_space():
            return self._space
        spaces = self.spaces
        result = []
        dtype, prev_dtype = None, None
        for space in spaces:
            if isinstance(space, gym.spaces.Box):
                dtype = 'box'
                result.append([space.low, space.high])
            elif isinstance(space, gym.spaces.Discrete):
                dtype = 'discrete'
                result.append(space.n)
            else:
                raise NotImplementedError

            if prev_dtype is not None and dtype != prev_dtype:
                return self.space

            prev_dtype = dtype

        if dtype == 'box':
            low = np.concatenate([res[0] for res in result])
            high = np.concatenate([res[1] for res in result])
            return gym.spaces.Box(low=low, high=high, dtype=np.float32)
        elif dtype == 'discrete':
            return gym.spaces.Discrete(n=np.sum(result))

        return self.space

    @property
    def name(self):
        """
        Return the name of the action.
        """
        if self._name is None:
            return self.__class__.__name__
        return self._name

    @name.setter
    def name(self, name):
        """
        Set the name of the action.
        """
        if name is None:
            name = self.__class__.__name__
        if not isinstance(name, str):
            raise TypeError("Expecting the name to be a string.")
        self._name = name

    @property
    def shape(self):
        """
        Shape of each item of `self.data`. Some actions, such as camera actions have more than 1 dimension.
        """
        shapes = []
        for array in self.data:
            shapes.append(array.shape)
        return shapes

    @property
    def merged_shape(self):
        """
        Shape of each item of `self.merged_data`.
        """
        shapes = []
        for array in self.merged_data:
            shapes.append(array.shape)
        return shapes

    @property
    def size(self):
        """
        Size (`.size` attribute) of each item of `self.data`.
        """
        sizes = []
        for array in self.data:
            sizes.append(array.size)
        return sizes

    @property
    def merged_size(self):
        """
        Size (`.size` attribute) of each item of `self.merged_data`.
        """
        sizes = []
        for array in self.merged_data:
            sizes.append(array.size)
        return sizes

    @property
    def dimension(self):
        """
        Dimension (length of the shape) of each item of `self.data`.
        """
        dims = []
        for array in self.data:
            dims.append(len(array.shape))
        return dims

    @property
    def merged_dimension(self):
        """
        Dimension (length of the shape) of each item of `self.merged_data`.
        """
        dims = []
        for array in self.merged_data:
            dims.append(len(array.shape))
        return dims

    @property
    def num_dimensions(self):
        """
        Number of distinct values found in `self.dimension` (i.e. number of different lengths of shape).
        """
        distinct = np.unique(self.dimension)
        return len(distinct)

    # @property
    # def distribution(self):
    #     """
    #     Get the current distribution used when sampling the action
    #     """
    #     return None
    #
    # @distribution.setter
    # def distribution(self, distribution):
    #     """
    #     Set the distribution to the action.
    #     """
    #     # check if distribution is discrete/continuous
    #     pass

    ###########
    # Methods #
    ###########

    def is_combined_actions(self):
        """
        Return a boolean value depending if the action is a combination of actions.

        Returns:
            bool: True if the action is a combination of actions, False otherwise.
        """
        return len(self._actions) > 0

    # alias
    has_actions = is_combined_actions

    def has_data(self):
        return self._data is not None

    def has_space(self):
        return self._space is not None

    def add(self, action):
        """
        Add a action or a list of actions to the list of internal actions. Useful when combining different actions
        together. This shouldn't be called if this action has some data set to it.

        Args:
            action (Action, list/tuple of Action): action(s) to add to the internal list of actions
        """
        if self.has_data():
            raise AttributeError(
                "Undefined behavior: a action should be a combination of actions or should contain "
                "some kind of data, but not both.")
        if isinstance(action, Action):
            self._actions.add(action)
        elif isinstance(action, collections.Iterable):
            for i, s in enumerate(action):
                if not isinstance(s, Action):
                    raise TypeError(
                        "The item {} in the given list is not an instance of Action"
                        .format(i))
                self._actions.add(s)
        else:
            raise TypeError(
                "The 'other' argument should be an instance of Action, or an iterator over actions."
            )

    # alias
    append = add
    extend = add

    def _write(self, data):
        pass

    def write(self, data=None):
        """
        Write the action values to the simulator for each action.
        This has to be overwritten by the child class.
        """
        # if time to write
        if self.cnt % self.ticks == 0:

            if self.has_data():  # write the current action
                if data is None:
                    data = self._data
                self._write(data)
            else:  # write each action
                if self.actions:
                    if data is None:
                        data = [None] * len(self.actions)
                    for action, d in zip(self.actions, data):
                        if d is None:
                            d = action._data
                        action._write(d)

        self.cnt += 1

        # return the data
        # return self.data

    # def _reset(self):
    #     pass
    #
    # def reset(self):
    #     """
    #     Some actions need to be reset. It returns the initial action.
    #     This needs to be overwritten by the child class.
    #
    #     Returns:
    #         initial action
    #     """
    #     if self.has_data(): # reset the current action
    #         self._reset()
    #     else: # reset each action
    #         for action in self.actions:
    #             action._reset()
    #
    #     # return the first action data
    #     return self.write()

    # def shape(self):
    #     """
    #     Return the shape of each action. Some actions, such as camera actions have more than 1 dimension.
    #     """
    #     return [d.shape for d in self.data]
    #
    # def dimension(self):
    #     """
    #     Return the dimension (length of shape) of each action.
    #     """
    #     return [len(d.shape) for d in self.data]

    def max_dimension(self):
        """
        Return the largest value found in `self.dimension`.
        """
        dims = self.dimension
        return max(dims)

    # def size(self):
    #     """
    #     Return the size of each action.
    #     """
    #     return [d.size for d in self.data]

    def total_size(self):
        """
        Return the sum of the values found in `self.size`.
        """
        total = 0
        for item_size in self.size:
            total = total + item_size
        return total

    def has_discrete_values(self):
        """
        Does the action have discrete values?
        """
        if self._data is None:
            return [
                isinstance(action._space,
                           (gym.spaces.Discrete, gym.spaces.MultiDiscrete))
                for action in self._actions
            ]
        if isinstance(self._space,
                      (gym.spaces.Discrete, gym.spaces.MultiDiscrete)):
            return [True]
        return [False]

    def is_discrete(self):
        """
        If all the actions are discrete, then it is discrete.
        """
        values = self.has_discrete_values()
        if len(values) == 0:
            return False
        return all(values)

    def has_continuous_values(self):
        """
        Does the action have continuous values?
        """
        if self._data is None:
            return [
                isinstance(action._space, gym.spaces.Box)
                for action in self._actions
            ]
        if isinstance(self._space, gym.spaces.Box):
            return [True]
        return [False]

    def is_continuous(self):
        """
        If one of the action is continuous, then the action is considered to be continuous.
        """
        return any(self.has_continuous_values())

    def bounds(self):
        """
        If the action is continuous, it returns the lower and higher bounds of the action.
        If the action is discrete, it returns the maximum number of discrete values that the action can take.
        If the action is multi-discrete, it returns the maximum number of discrete values that each subaction can take.

        Returns:
            list/tuple: list of bounds if multiple actions, or bounds of this action
        """
        if self._data is None:
            return [action.bounds() for action in self._actions]
        if isinstance(self._space, gym.spaces.Box):
            return (self._space.low, self._space.high)
        elif isinstance(self._space, gym.spaces.Discrete):
            return (self._space.n, )
        elif isinstance(self._space, gym.spaces.MultiDiscrete):
            return (self._space.nvec, )
        raise NotImplementedError

    def apply(self, fct):
        """
        Call the given function on `self.data` and assign the result back to `self.data`.

        Args:
            fct (callable): function taking the current data and returning the new data.
        """
        current = self.data
        self.data = fct(current)

    def contains(self, x):  # parameter dependent of the action
        """
        Check if the argument is within the range/bound of the action.
        """
        return self._space.contains(x)

    def sample(
        self,
        distribution=None
    ):  # parameter dependent of the action (discrete and continuous distributions)
        """
        Sample some values from the action based on the given distribution.
        If no distribution is specified, it samples from a uniform distribution (default value).
        """
        if self.is_combined_actions():
            return [action.sample() for action in self._actions]
        if self._distribution is None:
            return
        else:
            pass
        raise NotImplementedError

    def add_noise(self,
                  noise=None,
                  replace=True):  # parameter dependent of the action
        """
        Add some noise to the action, and returns it.

        Args:
            noise (np.ndarray, fct): array to be added or function to be applied on the data
        """
        if self._data is None:
            # apply noise
            for action in self._actions:
                action.add_noise(noise=noise)
        else:
            # add noise to the data
            noisy_data = self.data + noise
            # clip such that the data is within the bounds
            self.data = noisy_data

    def normalize(self,
                  normalizer=None,
                  replace=True):  # parameter dependent of the action
        """
        Normalize using the action data using the provided normalizer.

        Args:
            normalizer (sklearn.preprocessing.Normalizer): the normalizer to apply to the data.
            replace (bool): if True, it will replace the `data` attribute by the normalized data.

        Returns:
            the normalized data
        """
        pass

    def fuse(self, other=None, axis=0):
        """
        Fuse the actions that have the same shape together. The axis specified along which axis we concatenate the data.
        If multiple actions with different shapes are present, the axis will be the one specified if possible,
        otherwise it will be min(dimension, axis).

        Examples:
            a0 = JointPositionAction(robot)
            a1 = JointVelocityAction(robot)
            a = a0 & a1
            print(a)
            print(a.shape)
            a = a0 + a1
            a.fuse()
            print(a)
            print(a.shape)
        """
        # check argument
        if not (other is None or isinstance(other, Action)):
            raise TypeError(
                "The 'other' argument should be None or another action.")

        # build list of all the actions
        actions = [self] if self.has_data() else self._actions
        if other is not None:
            if other.has_data():
                actions.append(other)
            else:
                actions.extend(other._actions)

        # check if only one action
        if len(actions) < 2:
            return self  # do nothing

        # build the dictionary with key=dimension of shape, value=list of actions
        dic = {}
        for action in actions:
            dic.setdefault(len(action._data.shape), []).append(action)

        # traverse the dictionary and fuse corresponding shapes
        actions = []
        for key, value in dic.items():
            if len(value) > 1:
                # fuse
                data = [action._data for action in value]
                names = [action.name for action in value]
                a = Action(data=np.concatenate(data, axis=min(axis, key)),
                           name='+'.join(names))
                actions.append(a)
            else:
                # only one action
                actions.append(value[0])

        # return the fused action
        if len(actions) == 1:
            return actions[0]
        return Action(actions)

    def lookfor(self, class_type):
        """
        Look for the specified class type/name in the list of internal actions, and returns it.

        Args:
            class_type (type, str): class type or name

        Returns:
            Action: the corresponding instance of the Action class
        """
        # if string, lowercase it
        if isinstance(class_type, str):
            class_type = class_type.lower()

        # if there is one action
        if self.has_data():
            if self.__class__ == class_type or self.__class__.__name__.lower(
            ) == class_type:
                return self

        # the action has multiple actions, thus we go through each action
        for action in self.actions:
            if action.__class__ == class_type or action.__class__.__name__.lower(
            ) == class_type:
                return action

    ########################
    # Operator Overloading #
    ########################

    def __str__(self):
        """Return a string describing the action."""
        if self._data is None:
            lst = [self.__class__.__name__ + '(']
            for action in self.actions:
                lst.append('\t' + action.__str__() + ',')
            lst.append(')')
            return '\n'.join(lst)
        else:
            return '%s(%s)' % (self.name, self._data)

    def __call__(self, data=None):
        """
        Compute/read the action and return it. It is an alias to the `self.write()` method.
        """
        return self.write(data)

    def __len__(self):
        """
        Return the total number of actions contained in this class.

        Example::

            s1 = JntPositionAction(robot)
            s2 = s1 + JntVelocityAction(robot)
            print(len(s1)) # returns 1
            print(len(s2)) # returns 2
        """
        if self._data is None:
            return len(self._actions)
        return 1

    def __iter__(self):
        """
        Iterator over the actions.
        """
        if self.is_combined_actions():
            for action in self._actions:
                yield action
        else:
            yield self

    def __contains__(self, item):
        """
        Check if the action item(s) is(are) in the combined action. If the item is the data associated with the action,
        it checks that it is within the bounds.

        Args:
            item (Action, list/tuple of action, type): check if given action(s) is(are) in the combined action

        Example:
            s1 = JntPositionAction(robot)
            s2 = JntVelocityAction(robot)
            s = s1 + s2
            print(s1 in s) # output True
            print(s2 in s1) # output False
            print((s1, s2) in s) # output True
        """
        # check type of item
        if not isinstance(item, (Action, np.ndarray, type)):
            raise TypeError(
                "Expecting an Action, a np.array, or a class type, instead got: {}"
                .format(type(item)))

        # if class type
        if isinstance(item, type):
            # if there is one action
            if self.has_data():
                return self.__class__ == item
            # the action has multiple actions, thus we go through each action
            for action in self.actions:
                if action.__class__ == item:
                    return True
            return False

        # check if action item is in the combined action
        if self._data is None and isinstance(item, Action):
            return item in self._actions

        # check if action/data is within the bounds
        if isinstance(item, Action):
            item = item.data

        # check if continuous
        # if self.is_continuous():
        #     low, high = self.bounds()
        #     return np.all(low <= item) and np.all(item <= high)
        # else: # discrete case
        #     num = self.bounds()[0]
        #     # check the size of data
        #     if item.size > 1: # array
        #         return (item.size < num)
        #     else: # one number
        #         return (item[0] < num)

        return self.contains(item)

    def __getitem__(self, key):
        """
        Get the corresponding item from the action(s)
        """
        # if one action, slice the corresponding action data
        if len(self._actions) == 0:
            return self._data[key]
        # if multiple actions
        if isinstance(key, int):
            # get one action
            return self._actions[key]
        elif isinstance(key, slice):
            # get multiple actions
            return Action(self._actions[key])
        else:
            raise TypeError(
                "Expecting an int or slice for the key, but got instead {}".
                format(type(key)))

    def __setitem__(self, key, value):
        """
        Set the corresponding item/value to the corresponding key.

        Args:
            key (int, slice): index of the internal action, or index/indices for the action data
            value (Action, int/float, array): value to be set
        """
        if self.is_combined_actions():
            # set/move the action to the specified key
            if isinstance(value, Action) and isinstance(key, int):
                self._actions[key] = value
            else:
                raise TypeError(
                    "Expecting key to be an int, and value to be a action.")
        else:
            # set the value on the data directly
            self._data[key] = value

    def __add__(self, other):
        """
        Combine two different actions together. In this special case, the operation is not commutable.
        This is the same as taking the union of the actions.

        Args:
            other (Action): another action

        Returns:
            Action: the combined action

        Examples:
            s1 = JntPositionAction(robot)
            s2 = JntVelocityAction(robot)
            s = s1 + s2     # = Action([JntPositionAction(robot), JntVelocityAction(robot)])

            s1 = Action([JntPositionAction(robot), JntVelocityAction(robot)])
            s2 = Action([JntPositionAction(robot), LinkPositionAction(robot)])
            s = s1 + s2     # = Action([JntPositionAction(robot), JntVelocityAction(robot), LinkPositionAction(robot)])
        """
        if not isinstance(other, Action):
            raise TypeError("Expecting another action, instead got {}".format(
                type(other)))
        s1 = self._actions if self._data is None else OrderedSet([self])
        s2 = other._actions if other._data is None else OrderedSet([other])
        s = s1 + s2
        return Action(s)

    def __iadd__(self, other):
        """
        Add a action to the current one.

        Args:
            other (Action, list/tuple of Action): other action

        Examples:
            s = Action()
            s += JntPositionAction(robot)
            s += JntVelocityAction(robot)
        """
        if self._data is not None:
            raise AttributeError(
                "The current class already has some data attached to it. This operation can not be "
                "applied in this case.")
        self.append(other)

    def __sub__(self, other):
        """
        Remove the other action(s) from the current action.

        Args:
            other (Action): action to be removed.
        """
        if not isinstance(other, Action):
            raise TypeError("Expecting another action, instead got {}".format(
                type(other)))
        s1 = self._actions if self._data is None else OrderedSet([self])
        s2 = other._actions if other._data is None else OrderedSet([other])
        s = s1 - s2
        if len(s) == 1:  # just one element
            return s[0]
        return Action(s)

    def __isub__(self, other):
        """
        Remove one or several actions from the combined action.

        Args:
            other (Action): action to be removed.
        """
        if not isinstance(other, Action):
            raise TypeError("Expecting another action, instead got {}".format(
                type(other)))
        if self._data is not None:
            raise RuntimeError(
                "This operation is only available for a combined action")
        s = other._actions if other._data is None else OrderedSet([other])
        self._actions -= s

    def __copy__(self):
        """
        Return a shallow copy of the action. This can be overridden in the child class.

        A new instance of the same class is created from the actions, data, space, name and ticks of this action.
        """
        kwargs = {
            'actions': self.actions,
            'data': self._data,
            'space': self._space,
            'name': self.name,
            'ticks': self.ticks,
        }
        return self.__class__(**kwargs)

    def __deepcopy__(self, memo={}):
        """Return a deep copy of the action. This can be overridden in the child class.

        Args:
            memo (dict): memo dictionary of objects already copied during the current copying pass
        """
        if self in memo:
            return memo[self]

        actions = [copy.deepcopy(action, memo) for action in self.actions]
        data = copy.deepcopy(self._data)
        space = copy.deepcopy(self._space)
        action = self.__class__(actions=actions,
                                data=data,
                                space=space,
                                name=self.name,
                                ticks=self.ticks)

        memo[self] = action
        return action