Example #1
0
    def __init__(
        self,
        env,
        n_bins_obs=10,
        memory_size=100,
        state_preprocess_fn=None,
        state_preprocess_kwargs=None,
    ):
        """Wrap ``env`` and create the visit counters and trajectory memory.

        Parameters
        ----------
        env
            Environment to wrap. Its action space must be ``spaces.Discrete``;
            its observation space must be ``spaces.Box`` unless a
            ``state_preprocess_fn`` is given.
        n_bins_obs : int, default 10
            Forwarded as ``n_bins_obs`` to both ``DiscreteCounter`` objects.
        memory_size : int, default 100
            Argument passed to ``TrajectoryMemory``.
        state_preprocess_fn : callable, optional
            Stored on the instance; ``identity`` is used when it is falsy.
        state_preprocess_kwargs : dict, optional
            Stored on the instance; an empty dict is used when it is falsy.
        """
        Wrapper.__init__(self, env)

        # The Box observation-space requirement only applies when no
        # preprocessing function is supplied.
        assert state_preprocess_fn is not None or isinstance(
            env.observation_space, spaces.Box)
        assert isinstance(env.action_space, spaces.Discrete)

        self.state_preprocess_fn = (state_preprocess_fn
                                    if state_preprocess_fn else identity)
        self.state_preprocess_kwargs = (state_preprocess_kwargs
                                        if state_preprocess_kwargs else {})

        self.memory = TrajectoryMemory(memory_size)

        # Both counters share the wrapped env's spaces and binning. Going by
        # the names, one accumulates over the whole run and the other is
        # per-episode -- TODO confirm against the class's reset/step.
        counter_spaces = (self.env.observation_space, self.env.action_space)
        self.total_visit_counter = DiscreteCounter(*counter_spaces,
                                                   n_bins_obs=n_bins_obs)
        self.episode_visit_counter = DiscreteCounter(*counter_spaces,
                                                     n_bins_obs=n_bins_obs)

        # NOTE(review): ``curret_step`` is spelled this way in the original.
        # It is kept unchanged because other methods of this class (not
        # visible here) may rely on that exact attribute name.
        self.current_state = None
        self.curret_step = 0
def test_discrete_env(rate_power):
    """Check DiscreteCounter bookkeeping on the tabular GridWorld.

    For each visit budget ``N`` in 10..19, every (state, action) pair is
    sampled exactly ``N`` times. The test then checks:

    * the raw and reported counts;
    * the ``measure`` value, only for rate powers 1 and 0.5;
    * the number of visited states;
    * the visitation entropy.

    The counter is reset before the next budget.
    """
    env = GridWorld()
    counter = DiscreteCounter(env.observation_space,
                              env.action_space,
                              rate_power=rate_power)
    n_states = env.observation_space.n
    n_actions = env.action_space.n

    for n_visits in range(10, 20):
        # A fresh (or freshly reset) counter must be empty.
        assert counter.get_n_visited_states() == 0
        assert counter.get_entropy() == 0.0

        # Expected measure for the rate powers this test knows about;
        # any other rate power skips the measure check.
        if rate_power == pytest.approx(1):
            expected_measure = 1.0 / n_visits
        elif rate_power == pytest.approx(0.5):
            expected_measure = np.sqrt(1.0 / n_visits)
        else:
            expected_measure = None

        for state in range(n_states):
            for action in range(n_actions):
                for _ in range(n_visits):
                    next_state, reward, _, _ = env.sample(state, action)
                    counter.update(state, action, next_state, reward)
                assert counter.N_sa[state, action] == n_visits
                assert counter.count(state, action) == n_visits
                if expected_measure is not None:
                    assert np.allclose(counter.measure(state, action),
                                       expected_measure)

        # Every state was visited equally often, so the test expects all
        # states counted and an entropy of log2(number of states).
        assert counter.get_n_visited_states() == n_states
        assert np.allclose(counter.get_entropy(), np.log2(n_states))

        counter.reset()
Example #3
0
    def reset(self, **kwargs):
        """Reinitialise visit counts, value tables and the state counter.

        Extra keyword arguments are accepted and ignored.
        """
        horizon = self.horizon
        n_states = self.env.observation_space.n
        n_actions = self.env.action_space.n

        # (s, a) visit counter, one table per stage
        self.N_sa = np.zeros((horizon, n_states, n_actions))

        # steps_left[h] == horizon - h, i.e. [horizon, horizon - 1, ..., 1];
        # this is the initial value placed at stage h.
        steps_left = np.arange(horizon, 0, -1, dtype=float)

        # V carries an extra final stage (index `horizon`) that stays at 0.
        self.V = np.zeros((horizon + 1, n_states))
        self.V[:horizon, :] = steps_left[:, np.newaxis]
        stage_fill = steps_left[:, np.newaxis, np.newaxis]
        self.Q = np.full((horizon, n_states, n_actions), stage_fill)
        self.Q_bar = np.full((horizon, n_states, n_actions), stage_fill)

        # With add_bonus_after_update, Q starts at zero instead.
        if self.add_bonus_after_update:
            self.Q.fill(0.0)

        # episode counter
        self.episode = 0

        # helper that tracks the number of visited states and the entropy
        # of the visited-state distribution
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)
Example #4
0
    def reset(self, **kwargs):
        """Reinitialise counts, value tables, counter, reward log and writer.

        Extra keyword arguments are accepted and ignored.
        """
        horizon = self.horizon
        n_states = self.env.observation_space.n
        n_actions = self.env.action_space.n

        # (s, a) visit counter, one table per stage
        self.N_sa = np.zeros((horizon, n_states, n_actions))

        # steps_left[h] == horizon - h, i.e. [horizon, horizon - 1, ..., 1];
        # this is the initial value placed at stage h.
        steps_left = np.arange(horizon, 0, -1, dtype=float)

        # V carries an extra final stage (index `horizon`) that stays at 0.
        self.V = np.zeros((horizon + 1, n_states))
        self.V[:horizon, :] = steps_left[:, np.newaxis]
        stage_fill = steps_left[:, np.newaxis, np.newaxis]
        self.Q = np.full((horizon, n_states, n_actions), stage_fill)
        self.Q_bar = np.full((horizon, n_states, n_actions), stage_fill)

        # With add_bonus_after_update, Q starts at zero instead.
        if self.add_bonus_after_update:
            self.Q.fill(0.0)

        # episode counter
        self.episode = 0

        # helper that tracks the number of visited states and the entropy
        # of the visited-state distribution
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)

        # one reward slot per episode (presumably filled during training --
        # TODO confirm)
        self._rewards = np.zeros(self.n_episodes)

        # Default writer; the logging interval scales with the logger's
        # effective level.
        self.writer = PeriodicWriter(self.name,
                                     log_every=5 * logger.getEffectiveLevel())
Example #5
0
    def reset(self, **kwargs):
        """Reinitialise prior statistics, value tables and the state counter.

        Extra keyword arguments are accepted and ignored.
        """
        horizon = self.horizon
        n_states = self.env.observation_space.n
        n_actions = self.env.action_space.n

        # Prior tables get a leading stage axis only when stage dependent.
        stage_axis = (horizon,) if self.stage_dependent else ()
        sa_shape = stage_axis + (n_states, n_actions)
        sas_shape = sa_shape + (n_states,)

        # Prior over transitions: a constant scale_prior_transition for every
        # (s, a, s') entry (presumably Dirichlet pseudo-counts -- TODO confirm).
        self.N_sas = self.scale_prior_transition * np.ones(sas_shape)

        # Prior over rewards: two entries per (s, a), each set to
        # scale_prior_reward.
        self.M_sa = self.scale_prior_reward * np.ones(sa_shape + (2, ))

        # Value tables are always stage indexed; the *_policy copies serve
        # the recommended policy.
        v_shape = (horizon, n_states)
        q_shape = (horizon, n_states, n_actions)
        self.V = np.zeros(v_shape)
        self.Q = np.zeros(q_shape)
        self.V_policy = np.zeros(v_shape)
        self.Q_policy = np.zeros(q_shape)

        # episode counter
        self.episode = 0

        # helper that tracks the number of visited states and the entropy
        # of the visited-state distribution
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)
Example #6
0
    def reset(self, **kwargs):
        """Reinitialise the following, then refresh the agent's name and writer:

        * visit counts and bonuses;
        * the MDP estimate;
        * the value tables;
        * the state counter;
        * the reward log.

        Extra keyword arguments are accepted and ignored.
        """
        horizon = self.horizon
        n_states = self.env.observation_space.n
        n_actions = self.env.action_space.n

        # Estimation tables get a leading stage axis only when stage dependent.
        stage_axis = (horizon,) if self.stage_dependent else ()
        sa_shape = stage_axis + (n_states, n_actions)
        sas_shape = sa_shape + (n_states,)

        # (s, a) visit counts and bonuses (bonuses rescaled below)
        self.N_sa = np.zeros(sa_shape)
        self.B_sa = np.ones(sa_shape)

        # MDP estimate: zero rewards, uniform next-state distribution
        self.R_hat = np.zeros(sa_shape)
        self.P_hat = np.ones(sas_shape) / n_states

        # Value tables; V starts at 1 and is rescaled below. The *_policy
        # copies serve the recommended policy.
        self.V = np.ones((horizon, n_states))
        self.Q = np.zeros((horizon, n_states, n_actions))
        self.V_policy = np.zeros((horizon, n_states))
        self.Q_policy = np.zeros((horizon, n_states, n_actions))

        # Bonus and V start at v_max[h] per stage when stage dependent,
        # otherwise at v_max[0] everywhere.
        if self.stage_dependent:
            for stage in range(horizon):
                self.B_sa[stage, :, :] = self.v_max[stage]
                self.V[stage, :] = self.v_max[stage]
        else:
            self.B_sa *= self.v_max[0]
            self.V *= self.v_max[0]

        # episode counter
        self.episode = 0

        # helper that tracks the number of visited states and the entropy
        # of the visited-state distribution
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)

        # one reward slot per episode
        self._rewards = np.zeros(self.n_episodes)

        # The name must be final before the writer below captures it.
        if self.real_time_dp:
            self.name = 'UCBVI-RTDP'

        # Default writer; the logging interval scales with the logger's
        # effective level.
        self.writer = PeriodicWriter(self.name,
                                     log_every=5 * logger.getEffectiveLevel())
Example #7
0
def test_discrete_env():
    """Sampling each GridWorld (state, action) pair N times yields count N.

    Budgets N run over 10..19, and the counter is reset between budgets.
    """
    env = GridWorld()
    counter = DiscreteCounter(env.observation_space, env.action_space)
    n_states, n_actions = env.observation_space.n, env.action_space.n

    for n_visits in range(10, 20):
        for state in range(n_states):
            for action in range(n_actions):
                for _ in range(n_visits):
                    next_state, reward, _, _ = env.sample(state, action)
                    counter.update(state, action, next_state, reward)
                assert counter.N_sa[state, action] == n_visits
                assert counter.count(state, action) == n_visits
        counter.reset()
Example #8
0
def test_continuous_state_env():
    """Check counts on the continuous-state MountainCar after discretization.

    For each budget N, 100 random (state, action) pairs are each sampled
    N times. The counter is reset after every pair, so the discretized cell
    of that pair must hold exactly N.
    """
    env = MountainCar()
    counter = DiscreteCounter(env.observation_space, env.action_space)

    for n_visits in (10, 20, 30):
        for _ in range(100):
            state = env.observation_space.sample()
            action = env.action_space.sample()
            for _ in range(n_visits):
                next_state, reward, _, _ = env.sample(state, action)
                counter.update(state, action, next_state, reward)

            cell = counter.state_discretizer.discretize(state)
            assert counter.N_sa[cell, action] == n_visits
            assert counter.count(state, action) == n_visits
            # Reset per pair so a later pair landing in the same cell does
            # not accumulate counts.
            counter.reset()
Example #9
0
    def reset(self, **kwargs):
        """Reinitialise the following, then refresh the agent's name:

        * visit counts and bonuses;
        * the MDP estimate;
        * the value tables;
        * the state counter.

        Extra keyword arguments are accepted and ignored.
        """
        horizon = self.horizon
        n_states = self.env.observation_space.n
        n_actions = self.env.action_space.n

        # Estimation tables get a leading stage axis only when stage dependent.
        stage_axis = (horizon,) if self.stage_dependent else ()
        sa_shape = stage_axis + (n_states, n_actions)
        sas_shape = sa_shape + (n_states,)

        # visit counter
        self.N_sa = np.zeros(sa_shape)
        # NOTE(review): the bonus table is always (horizon, S, A), even when
        # N_sa has no stage axis -- confirm this asymmetry is intended.
        self.B_sa = np.zeros((horizon, n_states, n_actions))

        # MDP estimate: zero rewards, uniform next-state distribution
        self.R_hat = np.zeros(sa_shape)
        self.P_hat = np.ones(sas_shape) / n_states

        # Value tables; every row of V is overwritten below. The *_policy
        # copies serve the recommended policy.
        self.V = np.ones((horizon, n_states))
        self.Q = np.zeros((horizon, n_states, n_actions))
        self.V_policy = np.zeros((horizon, n_states))
        self.Q_policy = np.zeros((horizon, n_states, n_actions))

        # Stage h bonus and value both start at v_max[h].
        for stage in range(horizon):
            stage_max = self.v_max[stage]
            self.B_sa[stage, :, :] = stage_max
            self.V[stage, :] = stage_max

        # episode counter
        self.episode = 0

        # helper that tracks the number of visited states and the entropy
        # of the visited-state distribution
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)

        if self.real_time_dp:
            self.name = "UCBVI-RTDP"
def test_continuous_state_env(rate_power):
    """Check counts and ``measure`` on MountainCar for a given rate power.

    For each budget N, 50 random (state, action) pairs are each sampled
    N times. The counter is reset after every pair. The ``measure`` value
    is only checked for rate powers 1 and 0.5.
    """
    env = MountainCar()
    counter = DiscreteCounter(env.observation_space,
                              env.action_space,
                              rate_power=rate_power)

    for n_visits in (10, 20):
        # Expected measure for the known rate powers; None skips the check.
        if rate_power == pytest.approx(1):
            expected_measure = 1.0 / n_visits
        elif rate_power == pytest.approx(0.5):
            expected_measure = np.sqrt(1.0 / n_visits)
        else:
            expected_measure = None

        for _ in range(50):
            state = env.observation_space.sample()
            action = env.action_space.sample()
            for _ in range(n_visits):
                next_state, reward, _, _ = env.sample(state, action)
                counter.update(state, action, next_state, reward)

            cell = counter.state_discretizer.discretize(state)
            assert counter.N_sa[cell, action] == n_visits
            assert counter.count(state, action) == n_visits
            if expected_measure is not None:
                assert np.allclose(counter.measure(state, action),
                                   expected_measure)
            counter.reset()
Example #11
0
    def reset(self, **kwargs):
        """Reinitialise noise scales, counts, MDP estimate, values and counter.

        Extra keyword arguments are accepted and ignored.
        """
        horizon = self.horizon
        n_states = self.env.observation_space.n
        n_actions = self.env.action_space.n

        # Estimation tables get a leading stage axis only when stage dependent.
        stage_axis = (horizon,) if self.stage_dependent else ()
        sa_shape = stage_axis + (n_states, n_actions)
        sas_shape = sa_shape + (n_states,)

        # Std priors (always stage indexed). std2_sa is rescaled per stage
        # below.
        full_shape = (horizon, n_states, n_actions)
        self.std1_sa = self.scale_std_noise * np.ones(full_shape)
        self.std2_sa = np.ones(full_shape)

        # visit counter, starting at one for every entry
        self.N_sa = np.ones(sa_shape)

        # MDP estimate: zero rewards, uniform next-state distribution
        self.R_hat = np.zeros(sa_shape)
        self.P_hat = np.ones(sas_shape) * 1.0 / n_states

        # Value tables; the *_policy copies serve the recommended policy.
        self.V = np.zeros((horizon, n_states))
        self.Q = np.zeros(full_shape)
        self.V_policy = np.zeros((horizon, n_states))
        self.Q_policy = np.zeros(full_shape)

        # Scale the stage-h slice of std2_sa by v_max[h].
        for stage in range(horizon):
            self.std2_sa[stage] *= self.v_max[stage]

        # episode counter
        self.episode = 0

        # helper that tracks the number of visited states and the entropy
        # of the visited-state distribution
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)
Example #12
0
 def uncertainty_estimator_fn(observation_space, action_space):
     """Return a DiscreteCounter over the given spaces using 20 observation bins."""
     return DiscreteCounter(observation_space, action_space, n_bins_obs=20)
Example #13
0
 def uncertainty_est_fn(observation_space, action_space):
     """Return a DiscreteCounter over the given spaces, using its defaults."""
     counter = DiscreteCounter(observation_space, action_space)
     return counter