Example #1
0
    def _standardize_user_data(self, s, a, s_next, r, absorbing, theta,
                               check_batch_dim=False):
        """Standardize a batch of transitions and the Q-function parameters.

        ``s``, ``a``, ``s_next`` and ``theta`` are passed through
        ``standardize_input_data``. The expected shapes are built from
        ``self.state_dim``, ``self.action_dim`` and
        ``self.bellman_model.n_inputs()``. The state/action shapes are only
        supplied when the corresponding dim is not None. Afterwards ``s``,
        ``a`` and ``s_next`` are checked for matching lengths with
        ``check_array_lengths``. ``r`` and ``absorbing`` are returned unchanged.

        Args:
            s (numpy.array): the samples of the state (nsamples, state_dim)
            a (numpy.array): the samples of the action (nsamples, action_dim)
            s_next (numpy.array): the samples of the next (reached) state
                (nsamples, state_dim)
            r (numpy.array): the samples of the reward (nsamples, ); returned
                unchanged
            absorbing: returned unchanged. NOTE(review): presumably the
                per-sample absorbing-state flags (nsamples, ) - confirm
                against callers.
            theta (numpy.array): the sample of the Q-function parameters
                (1, n_params)
            check_batch_dim (bool): default False. NOTE(review): this flag is
                currently not forwarded to ``standardize_input_data``; it is
                kept only for interface compatibility.

        Returns:
            The standardized values (s, a, s_next, r, absorbing, theta)

        """
        s = standardize_input_data(
            s, ['s'],
            [(None, self.state_dim)] if self.state_dim is not None else None,
            exception_prefix='state')
        a = standardize_input_data(
            a, ['a'],
            [(None, self.action_dim)] if self.action_dim is not None else None,
            exception_prefix='action')
        # r is not standardized (its standardization call was disabled); it is
        # returned as-is.
        s_next = standardize_input_data(
            s_next, ['s_next'],
            [(None, self.state_dim)] if self.state_dim is not None else None,
            exception_prefix='state_next')
        # The expected shape is wrapped in a one-element list (one shape per
        # name), matching the s/a/s_next calls above. A bare tuple
        # ``(None, n_inputs)`` would be indexed per name. That would make the
        # expected shape for 'theta' ``None``, so theta's width was never
        # validated.
        theta = standardize_input_data(
            theta, ['theta'],
            [(None, self.bellman_model.n_inputs())],
            exception_prefix='theta')
        check_array_lengths(s, a, s_next)
        return s, a, s_next, r, absorbing, theta
Example #2
0
    def _standardize_user_data(self, s, a, s_next, r, check_batch_dim=False):
        """Standardize a batch of transitions.

        ``s``, ``a`` and ``s_next`` are passed through
        ``standardize_input_data``. Expected shapes come from
        ``self.state_dim`` / ``self.action_dim`` and are only supplied when
        those dims are not None. The three arrays are then checked for
        matching lengths with ``check_array_lengths``. ``r`` is returned
        unchanged.

        Args:
            s (numpy.array): the samples of the state (nsamples, state_dim)
            a (numpy.array): the samples of the action (nsamples, action_dim)
            s_next (numpy.array): the samples of the next (reached) state (nsamples, state_dim)
            r (numpy.array): the samples of the reward (nsamples, ); returned unchanged
            check_batch_dim (bool): forwarded to ``standardize_input_data`` for
                s, a and s_next. Default False.

        Returns:
            The standardized values (s, a, s_next, r)

        """
        s = standardize_input_data(s, ['s'], [(None, self.state_dim)] if self.state_dim is not None else None,
                                   check_batch_dim=check_batch_dim, exception_prefix='state')
        a = standardize_input_data(a, ['a'], [(None, self.action_dim)] if self.action_dim is not None else None,
                                   check_batch_dim=check_batch_dim, exception_prefix='action')
        # r is not standardized (its standardization call was disabled); it is
        # returned as-is and is not part of the length check below.
        s_next = standardize_input_data(s_next, ['s_next'],
                                        [(None, self.state_dim)] if self.state_dim is not None else None,
                                        check_batch_dim=check_batch_dim, exception_prefix='state_next')
        check_array_lengths(s, a, s_next)
        return s, a, s_next, r
Example #3
0
    def _standardize_user_data(self, s, a, s_next, r, check_batch_dim=False):
        """Standardize a batch of transitions.

        ``s``, ``a`` and ``s_next`` are passed through
        ``standardize_input_data``. Expected shapes come from
        ``self.state_dim`` / ``self.action_dim`` and are only supplied when
        those dims are not None. The three arrays are then checked for
        matching lengths with ``check_array_lengths``. ``r`` is returned
        unchanged.

        Args:
            s (numpy.array): the samples of the state (nsamples, state_dim)
            a (numpy.array): the samples of the action (nsamples, action_dim)
            s_next (numpy.array): the samples of the next (reached) state
                (nsamples, state_dim)
            r (numpy.array): the samples of the reward (nsamples, ); returned
                unchanged
            check_batch_dim (bool): forwarded to ``standardize_input_data`` for
                s, a and s_next. Default False.

        Returns:
            The standardized values (s, a, s_next, r)

        """
        s = standardize_input_data(
            s, ['s'],
            [(None, self.state_dim)] if self.state_dim is not None else None,
            check_batch_dim=check_batch_dim,
            exception_prefix='state')
        a = standardize_input_data(
            a, ['a'],
            [(None, self.action_dim)] if self.action_dim is not None else None,
            check_batch_dim=check_batch_dim,
            exception_prefix='action')
        # r is not standardized (its standardization call was disabled); it is
        # returned as-is and is not part of the length check below.
        s_next = standardize_input_data(
            s_next, ['s_next'],
            [(None, self.state_dim)] if self.state_dim is not None else None,
            check_batch_dim=check_batch_dim,
            exception_prefix='state_next')
        check_array_lengths(s, a, s_next)
        return s, a, s_next, r
Example #4
0
    def _standardize_user_data(self,
                               s,
                               a,
                               s_next,
                               r,
                               absorbing,
                               theta,
                               check_batch_dim=False):
        """Standardize a batch of transitions and the Q-function parameters.

        ``s``, ``a``, ``s_next`` and ``theta`` are passed through
        ``standardize_input_data``. The expected shapes are built from
        ``self.state_dim``, ``self.action_dim`` and
        ``self.bellman_model.n_inputs()``. The state/action shapes are only
        supplied when the corresponding dim is not None. Afterwards ``s``,
        ``a`` and ``s_next`` are checked for matching lengths with
        ``check_array_lengths``. ``r`` and ``absorbing`` are returned unchanged.

        Args:
            s (numpy.array): the samples of the state (nsamples, state_dim)
            a (numpy.array): the samples of the action (nsamples, action_dim)
            s_next (numpy.array): the samples of the next (reached) state
                (nsamples, state_dim)
            r (numpy.array): the samples of the reward (nsamples, ); returned
                unchanged
            absorbing: returned unchanged. NOTE(review): presumably the
                per-sample absorbing-state flags (nsamples, ) - confirm
                against callers.
            theta (numpy.array): the sample of the Q-function parameters
                (1, n_params)
            check_batch_dim (bool): default False. NOTE(review): this flag is
                currently not forwarded to ``standardize_input_data``; it is
                kept only for interface compatibility.

        Returns:
            The standardized values (s, a, s_next, r, absorbing, theta)

        """
        s = standardize_input_data(
            s, ['s'],
            [(None, self.state_dim)] if self.state_dim is not None else None,
            exception_prefix='state')
        a = standardize_input_data(
            a, ['a'],
            [(None, self.action_dim)] if self.action_dim is not None else None,
            exception_prefix='action')
        # r is not standardized (its standardization call was disabled); it is
        # returned as-is.
        s_next = standardize_input_data(
            s_next, ['s_next'],
            [(None, self.state_dim)] if self.state_dim is not None else None,
            exception_prefix='state_next')
        # The expected shape is wrapped in a one-element list (one shape per
        # name), matching the s/a/s_next calls above. A bare tuple
        # ``(None, n_inputs)`` would be indexed per name. That would make the
        # expected shape for 'theta' ``None``, so theta's width was never
        # validated.
        theta = standardize_input_data(
            theta, ['theta'],
            [(None, self.bellman_model.n_inputs())],
            exception_prefix='theta')
        check_array_lengths(s, a, s_next)
        return s, a, s_next, r, absorbing, theta