def true_dynamics_pmf(self, s_i, a_i):
        """Return the next-state probability masses for the cart-pole model.

        Takes the state center ``s_i`` and the action center ``a_i`` and
        integrates the frictionless cart-pole equations for one semi-implicit
        Euler step of ``dt = 0.02`` using ``self.true_pars``. Velocities and
        positions are clipped to ``self.bounds``. The result is then projected
        onto the discretized state grid.

        Args:
            s_i: Row index into ``self.state_centers``. Columns 0-3 are read as
                ``[x, theta, xdot, thetadot]``.
            a_i: Index into ``self.action_centers``; the entry is the applied
                force ``u``.

        Returns:
            Array with one probability mass per row of ``self.state_centers``.
            If ``self.noise[1]`` is falsy, all mass sits on the nearest state
            center. Otherwise theta is spread with a Gaussian of scale
            ``self.noise[1]``. The spread covers only the centers that share
            the nearest center's x, xdot and thetadot.
        """
        x, theta, xdot, thetadot = self.state_centers[s_i, :4]
        u = self.action_centers[a_i]

        mc = self.true_pars[0]  # cart mass
        mp = self.true_pars[1]  # pole mass
        l = self.true_pars[2]   # pole length
        g = 9.8
        dt = 0.02

        # sin/cos are negated, i.e. evaluated at theta + pi. Presumably this
        # model's theta = 0 is the upright pole, while the equations below use
        # theta = 0 as hanging down — confirm.
        sin_t = -np.sin(theta)
        cos_t = -np.cos(theta)

        den = mc + mp * sin_t ** 2
        xdotdot = (u + (mp * sin_t * (l * (thetadot ** 2) + g * cos_t))) / den
        thetadotdot = (
            -u * cos_t
            - mp * l * (thetadot ** 2) * cos_t * sin_t
            - (mp + mc) * g * sin_t
        ) / (l * den)

        # Semi-implicit Euler: update velocities first, then advance positions
        # with the new velocities.
        xdot = min(max(xdot + xdotdot * dt, self.bounds[0, 2]), self.bounds[1, 2])
        thetadot = min(max(thetadot + thetadotdot * dt, self.bounds[0, 3]), self.bounds[1, 3])

        x = min(max(x + xdot * dt, self.bounds[0, 0]), self.bounds[1, 0])
        # Extra theta term scaled by input_pars[0]. Here -sin_t equals
        # sin(theta) of the pre-step angle.
        theta_next = theta + thetadot * dt + (-sin_t * np.sign(theta) * self.input_pars[0])
        theta = min(max(theta_next, self.bounds[0, 1]), self.bounds[1, 1])

        s_next = np.array([x, theta, xdot, thetadot])
        s_next_i = rl_tools.find_nearest_index_fast(self.dim_centers, s_next)

        pmf = np.zeros(self.state_centers.shape[0])
        if self.noise[1]:
            # Only theta is noisy. Hold every other dimension (x, xdot,
            # thetadot) at the nearest center; otherwise the mass would be
            # smeared across all cart positions.
            fixed_cols = [0, 2, 3]
            supported_s = np.all(
                np.equal(self.state_centers[:, fixed_cols], self.state_centers[s_next_i, fixed_cols]),
                axis=1,
            )
            pmf[supported_s] = scipy.stats.norm.pdf(
                self.state_centers[supported_s, 1], loc=s_next[1], scale=self.noise[1]
            )
            total = np.sum(pmf)
            if total > 0:
                pmf /= total
            else:
                # Every density underflowed (tiny noise scale). Fall back to
                # the deterministic projection instead of dividing by zero.
                pmf[:] = 0.0
                pmf[s_next_i] = 1.0
        else:
            pmf[s_next_i] = 1.0

        return pmf
    def true_dynamics_pmf(self, s_i, a_i):
        """Return the next-state probability masses for the 2-D [x, xdot] model.

        The velocity update ``0.001 * u + true_pars[0] * cos(true_pars[1] * x)``
        presumably makes this a mountain-car variant — confirm. Position
        advances by the pre-step velocity. Crossing a rock line reduces speed
        by a "slip" of at most ``self.noise[0]``. Both coordinates are clipped
        to ``self.bounds``, and the result is projected onto the state grid.

        Args:
            s_i: Row index into ``self.state_centers``. Column 0 is read as x
                and column 1 as xdot.
            a_i: Index into ``self.action_centers``; the entry is the action
                ``u``.

        Returns:
            Array with one probability mass per row of ``self.state_centers``.
            If ``self.noise[1]`` is falsy, all mass sits on the nearest state
            center. Otherwise x is spread with a Gaussian of scale
            ``self.noise[1]``. The spread covers only the centers that share
            the nearest center's xdot.
        """
        s = self.state_centers[s_i, :].copy()
        u = self.action_centers[a_i]
        x = s[0]
        xdot = s[1]

        # Position advances with the pre-step velocity.
        s[0] = min(max(x + xdot, self.bounds[0, 0]), self.bounds[1, 0])

        # Crossing a rock line means x and the new position lie on different
        # sides of it, as judged by np.sign. A crossing subtracts up to
        # noise[0] from |xdot|, and never more than |xdot| itself.
        # NOTE(review): an earlier comment listed rocks at -.25, .25 and .75,
        # but only .25 and .75 have ever been checked — confirm the intent.
        rocks = (0.25, 0.75)
        slip = 0
        if any(np.sign(x - r) != np.sign(s[0] - r) for r in rocks):
            if xdot > 0:
                slip = max(-self.noise[0], -xdot)
            else:
                slip = min(self.noise[0], -xdot)

        s[1] = min(
            max(
                xdot + 0.001 * u + (self.true_pars[0] * np.cos(self.true_pars[1] * x)) + slip,
                self.bounds[0, 1],
            ),
            self.bounds[1, 1],
        )

        s_next_i = rl_tools.find_nearest_index_fast(self.dim_centers, s)

        pmf = np.zeros(self.state_centers.shape[0])
        if self.noise[1]:
            # Only x is noisy: hold xdot at the nearest center's value.
            supported_s = self.state_centers[:, 1] == self.state_centers[s_next_i, 1]
            pmf[supported_s] = scipy.stats.norm.pdf(
                self.state_centers[supported_s, 0], loc=s[0], scale=self.noise[1]
            )
            total = np.sum(pmf)
            if total > 0:
                pmf /= total
            else:
                # Every density underflowed (tiny noise scale). Fall back to
                # the deterministic projection instead of dividing by zero.
                pmf[:] = 0.0
                pmf[s_next_i] = 1.0
        else:
            pmf[s_next_i] = 1.0

        return pmf
 def get_action(self, s):
     """Return the action stored for the grid state nearest to ``s``.

     ``s`` is snapped to its nearest discretized state. That state's entry in
     ``self.states_to_actions`` gives the index into ``self.action_centers``.
     """
     nearest_state = rl_tools.find_nearest_index_fast(self.dim_centers, s)
     action_index = self.states_to_actions[nearest_state]
     return self.action_centers[action_index]