Example #1
    def reset(self, **kwargs):
        H = self.horizon
        S = self.env.observation_space.n
        A = self.env.action_space.n

        # (s, a) visit counter
        self.N_sa = np.zeros((H, S, A))

        # Value functions
        self.V = np.ones((H + 1, S))
        self.V[H, :] = 0
        self.Q = np.ones((H, S, A))
        self.Q_bar = np.ones((H, S, A))
        for hh in range(self.horizon):
            self.V[hh, :] *= self.horizon - hh
            self.Q[hh, :, :] *= self.horizon - hh
            self.Q_bar[hh, :, :] *= self.horizon - hh

        if self.add_bonus_after_update:
            self.Q *= 0.0

        # ep counter
        self.episode = 0

        # useful object to compute total number of visited states & entropy of visited states
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)
Example #2
    def reset(self, **kwargs):
        H = self.horizon
        S = self.env.observation_space.n
        A = self.env.action_space.n

        if self.stage_dependent:
            shape_hsa = (H, S, A)
            shape_hsas = (H, S, A, S)
        else:
            shape_hsa = (S, A)
            shape_hsas = (S, A, S)

        # Prior transitions
        self.N_sas = self.scale_prior_transition * np.ones(shape_hsas)

        # Prior rewards
        self.M_sa = self.scale_prior_reward * np.ones(shape_hsa + (2, ))

        # Value functions
        self.V = np.zeros((H, S))
        self.Q = np.zeros((H, S, A))
        # for rec. policy
        self.V_policy = np.zeros((H, S))
        self.Q_policy = np.zeros((H, S, A))

        # ep counter
        self.episode = 0

        # useful object to compute total number of visited states & entropy of visited states
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)
Example #3
    def reset(self, **kwargs):
        H = self.horizon
        S = self.env.observation_space.n
        A = self.env.action_space.n

        # (s, a) visit counter
        self.N_sa = np.zeros((H, S, A))

        # Value functions
        self.V = np.ones((H + 1, S))
        self.V[H, :] = 0
        self.Q = np.ones((H, S, A))
        self.Q_bar = np.ones((H, S, A))
        for hh in range(self.horizon):
            self.V[hh, :] *= (self.horizon - hh)
            self.Q[hh, :, :] *= (self.horizon - hh)
            self.Q_bar[hh, :, :] *= (self.horizon - hh)

        if self.add_bonus_after_update:
            self.Q *= 0.0

        # ep counter
        self.episode = 0

        # useful object to compute total number of visited states & entropy of visited states
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)

        # info
        self._rewards = np.zeros(self.n_episodes)

        # default writer
        self.writer = PeriodicWriter(self.name,
                                     log_every=5 * logger.getEffectiveLevel())
Example #4
    def __init__(
        self,
        env,
        n_bins_obs=10,
        memory_size=100,
        state_preprocess_fn=None,
        state_preprocess_kwargs=None,
    ):
        Wrapper.__init__(self, env)

        if state_preprocess_fn is None:
            assert isinstance(env.observation_space, spaces.Box)
        assert isinstance(env.action_space, spaces.Discrete)

        self.state_preprocess_fn = state_preprocess_fn or identity
        self.state_preprocess_kwargs = state_preprocess_kwargs or {}

        self.memory = TrajectoryMemory(memory_size)
        self.total_visit_counter = DiscreteCounter(self.env.observation_space,
                                                   self.env.action_space,
                                                   n_bins_obs=n_bins_obs)
        self.episode_visit_counter = DiscreteCounter(
            self.env.observation_space,
            self.env.action_space,
            n_bins_obs=n_bins_obs)
        self.current_state = None
        self.current_step = 0
Example #5
    def reset(self, **kwargs):
        H = self.horizon
        S = self.env.observation_space.n
        A = self.env.action_space.n

        if self.stage_dependent:
            shape_hsa = (H, S, A)
            shape_hsas = (H, S, A, S)
        else:
            shape_hsa = (S, A)
            shape_hsas = (S, A, S)

        # (s, a) visit counter
        self.N_sa = np.zeros(shape_hsa)
        # (s, a) bonus
        self.B_sa = np.ones(shape_hsa)

        # MDP estimator
        self.R_hat = np.zeros(shape_hsa)
        self.P_hat = np.ones(shape_hsas) * 1.0 / S

        # Value functions
        self.V = np.ones((H, S))
        self.Q = np.zeros((H, S, A))
        # for rec. policy
        self.V_policy = np.zeros((H, S))
        self.Q_policy = np.zeros((H, S, A))

        # Init V and bonus
        if not self.stage_dependent:
            self.B_sa *= self.v_max[0]
            self.V *= self.v_max[0]
        else:
            for hh in range(self.horizon):
                self.B_sa[hh, :, :] = self.v_max[hh]
                self.V[hh, :] = self.v_max[hh]

        # ep counter
        self.episode = 0

        # useful object to compute total number of visited states & entropy of visited states
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)

        # info
        self._rewards = np.zeros(self.n_episodes)

        # update name
        if self.real_time_dp:
            self.name = 'UCBVI-RTDP'

        # default writer
        self.writer = PeriodicWriter(self.name,
                                     log_every=5 * logger.getEffectiveLevel())
Example #6
    def reset(self, **kwargs):
        H = self.horizon
        S = self.env.observation_space.n
        A = self.env.action_space.n

        if self.stage_dependent:
            shape_hsa = (H, S, A)
            shape_hsas = (H, S, A, S)
        else:
            shape_hsa = (S, A)
            shape_hsas = (S, A, S)

        # visit counter
        self.N_sa = np.zeros(shape_hsa)
        # bonus
        self.B_sa = np.zeros((H, S, A))

        # MDP estimator
        self.R_hat = np.zeros(shape_hsa)
        self.P_hat = np.ones(shape_hsas) * 1.0 / S

        # Value functions
        self.V = np.ones((H, S))
        self.Q = np.zeros((H, S, A))
        # for rec. policy
        self.V_policy = np.zeros((H, S))
        self.Q_policy = np.zeros((H, S, A))

        # Init V and bonus
        for hh in range(self.horizon):
            self.B_sa[hh, :, :] = self.v_max[hh]
            self.V[hh, :] = self.v_max[hh]

        # ep counter
        self.episode = 0

        # useful object to compute total number of visited states & entropy of visited states
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)

        # update name
        if self.real_time_dp:
            self.name = "UCBVI-RTDP"
Example #7
    def reset(self, **kwargs):
        H = self.horizon
        S = self.env.observation_space.n
        A = self.env.action_space.n

        if self.stage_dependent:
            shape_hsa = (H, S, A)
            shape_hsas = (H, S, A, S)
        else:
            shape_hsa = (S, A)
            shape_hsas = (S, A, S)

        # stds prior
        self.std1_sa = self.scale_std_noise * np.ones((H, S, A))
        self.std2_sa = np.ones((H, S, A))
        # visit counter
        self.N_sa = np.ones(shape_hsa)

        # MDP estimator
        self.R_hat = np.zeros(shape_hsa)
        self.P_hat = np.ones(shape_hsas) * 1.0 / S

        # Value functions
        self.V = np.zeros((H, S))
        self.Q = np.zeros((H, S, A))
        # for rec. policy
        self.V_policy = np.zeros((H, S))
        self.Q_policy = np.zeros((H, S, A))

        # Init V and variances
        for hh in range(self.horizon):
            self.std2_sa[hh, :, :] *= self.v_max[hh]

        # ep counter
        self.episode = 0

        # useful object to compute total number of visited states & entropy of visited states
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)
Example #8
def test_discrete_env():
    env = GridWorld()
    counter = DiscreteCounter(env.observation_space, env.action_space)

    for N in range(10, 20):
        for ss in range(env.observation_space.n):
            for aa in range(env.action_space.n):
                for _ in range(N):
                    ns, rr, _, _ = env.sample(ss, aa)
                    counter.update(ss, aa, ns, rr)
                assert counter.N_sa[ss, aa] == N
                assert counter.count(ss, aa) == N
        counter.reset()
Example #9
def test_continuous_state_env(rate_power):
    env = MountainCar()
    counter = DiscreteCounter(env.observation_space,
                              env.action_space,
                              rate_power=rate_power)

    for N in [10, 20]:
        for _ in range(50):
            ss = env.observation_space.sample()
            aa = env.action_space.sample()
            for _ in range(N):
                ns, rr, _, _ = env.sample(ss, aa)
                counter.update(ss, aa, ns, rr)

            dss = counter.state_discretizer.discretize(ss)
            assert counter.N_sa[dss, aa] == N
            assert counter.count(ss, aa) == N
            if rate_power == pytest.approx(1):
                assert np.allclose(counter.measure(ss, aa), 1.0 / N)
            elif rate_power == pytest.approx(0.5):
                assert np.allclose(counter.measure(ss, aa), np.sqrt(1.0 / N))
            counter.reset()
Example #10
def test_continuous_state_env():
    env = MountainCar()
    counter = DiscreteCounter(env.observation_space, env.action_space)

    for N in [10, 20, 30]:
        for _ in range(100):
            ss = env.observation_space.sample()
            aa = env.action_space.sample()
            for _ in range(N):
                ns, rr, _, _ = env.sample(ss, aa)
                counter.update(ss, aa, ns, rr)

            dss = counter.state_discretizer.discretize(ss)
            assert counter.N_sa[dss, aa] == N
            assert counter.count(ss, aa) == N
            counter.reset()
Example #11
class Vis2dWrapper(Wrapper):
    """
    Stores and visualizes trajectories of environments with 2D box observation
    spaces and discrete action spaces.

    Parameters
    ----------
    env: gym.Env
    n_bins_obs : int, default = 10
        Number of intervals to discretize each dimension of the observation space.
        Used to count number of visits.
    memory_size : int, default = 100
        Maximum number of trajectories to keep in memory.
        The most recent ones are kept.
    state_preprocess_fn : callable(state, env, **kwargs) -> np.ndarray, default: None
        Function that converts the state to a 2D array.
    state_preprocess_kwargs : dict, default: None
        kwargs for state_preprocess_fn
    """
    def __init__(
        self,
        env,
        n_bins_obs=10,
        memory_size=100,
        state_preprocess_fn=None,
        state_preprocess_kwargs=None,
    ):
        Wrapper.__init__(self, env)

        if state_preprocess_fn is None:
            assert isinstance(env.observation_space, spaces.Box)
        assert isinstance(env.action_space, spaces.Discrete)

        self.state_preprocess_fn = state_preprocess_fn or identity
        self.state_preprocess_kwargs = state_preprocess_kwargs or {}

        self.memory = TrajectoryMemory(memory_size)
        self.total_visit_counter = DiscreteCounter(self.env.observation_space,
                                                   self.env.action_space,
                                                   n_bins_obs=n_bins_obs)
        self.episode_visit_counter = DiscreteCounter(
            self.env.observation_space,
            self.env.action_space,
            n_bins_obs=n_bins_obs)
        self.current_state = None
        self.current_step = 0

    def reset(self):
        self.current_step = 0
        self.current_state = self.env.reset()
        return self.current_state

    def step(self, action):
        observation, reward, done, info = self.env.step(action)
        # initialize new trajectory
        if self.current_step == 0:
            self.memory.end_trajectory()
            self.episode_visit_counter.reset()
        self.current_step += 1
        # update counters
        ss, aa = self.current_state, action
        ns = observation
        self.total_visit_counter.update(ss, aa, ns, reward)
        self.episode_visit_counter.update(ss, aa, ns, reward)
        # store transition
        transition = Transition(
            ss,
            self.state_preprocess_fn(ss, self.env,
                                     **self.state_preprocess_kwargs),
            aa,
            reward,
            self.total_visit_counter.count(ss, aa),
            self.episode_visit_counter.count(ss, aa),
        )
        self.memory.append(transition)
        # update current state
        self.current_state = observation
        return observation, reward, done, info

    def plot_trajectories(
        self,
        fignum=None,
        figsize=(6, 6),
        hide_axis=True,
        show=True,
        video_filename=None,
        colormap_name="cool",
        framerate=15,
        n_skip=1,
        dot_scale_factor=2.5,
        alpha=0.25,
        xlim=None,
        ylim=None,
        dot_size_means="episode_visits",
    ):
        """
        Plot history of trajectories in a scatter plot.
        Colors distinguish recent from old trajectories, and the size of the dots
        represents the number of visits to a state.

        If video_filename is given, a video file is saved. Otherwise, only the
        final frame is plotted.

        Parameters
        ----------
        fignum : str
            Figure name
        figsize : (float, float)
            (width, height) of the image in inches.
        hide_axis : bool
            If True, axes are hidden.
        show : bool
            If True, calls plt.show()
        video_filename : str or None
            If not None, save a video with given filename.
        colormap_name : str, default = 'cool'
            Colormap name.
            See https://matplotlib.org/tutorials/colors/colormaps.html
        framerate : int, default: 15
            Video framerate.
        n_skip : int, default: 1
            Skip period: every n_skip trajectories, one trajectory is plotted.
        dot_scale_factor : double
            Scale factor for scatter plot points.
        alpha : float, default: 0.25
            The alpha blending value, between 0 (transparent) and 1 (opaque).
        xlim: list, default: None
            x plot limits, set to [0, 1] if None
        ylim: list, default: None
            y plot limits, set to [0, 1] if None
        dot_size_means : str, {'episode_visits' or 'total_visits'}, default: 'episode_visits'
            Whether to scale the dot size with the number of visits in an episode
            or the total number of visits during the whole interaction.
        """
        logger.info("Plotting...")

        fignum = fignum or str(self)
        colormap_fn = plt.get_cmap(colormap_name)

        # discretizer
        try:
            discretizer = self.episode_visit_counter.state_discretizer
            epsilon = min(
                discretizer._bins[0][1] - discretizer._bins[0][0],
                discretizer._bins[1][1] - discretizer._bins[1][0],
            )
        except Exception:
            epsilon = 0.01

        # figure setup
        xlim = xlim or [0.0, 1.0]
        ylim = ylim or [0.0, 1.0]

        fig = plt.figure(fignum, figsize=figsize)
        fig.clf()
        canvas = FigureCanvas(fig)
        images = []
        ax = fig.gca()

        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        if hide_axis:
            ax.set_axis_off()

        # scatter plot
        indices = np.arange(self.memory.n_trajectories)[::n_skip]

        for idx in indices:
            traj = self.memory.trajectories[idx]
            color_time_intensity = (idx + 1) / self.memory.n_trajectories
            color = colormap_fn(color_time_intensity)

            states = np.array([traj[ii].state for ii in range(len(traj))])

            if dot_size_means == "episode_visits":
                sizes = np.array(
                    [traj[ii].n_episode_visits for ii in range(len(traj))])
            elif dot_size_means == "total_visits":
                raw_states = [traj[ii].raw_state for ii in range(len(traj))]
                sizes = np.array([
                    np.sum([
                        self.total_visit_counter.count(ss, aa)
                        for aa in range(self.env.action_space.n)
                    ]) for ss in raw_states
                ])
            else:
                raise ValueError()

            sizes = 1 + sizes
            sizes = (dot_scale_factor**2) * 100 * epsilon * sizes / sizes.max()

            ax.scatter(x=states[:, 0],
                       y=states[:, 1],
                       color=color,
                       s=sizes,
                       alpha=alpha)
            plt.tight_layout()

            if video_filename is not None:
                canvas.draw()
                image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(),
                                                dtype=np.uint8)
                image_from_plot = image_from_plot.reshape(
                    fig.canvas.get_width_height()[::-1] + (3, ))
                images.append(image_from_plot)

        if video_filename is not None:
            logger.info("... writing video ...")
            video_write(video_filename, images, framerate=framerate)

        logger.info("... done!")

        if show:
            plt.show()

    def plot_trajectory_actions(
        self,
        fignum=None,
        figsize=(8, 6),
        n_traj_to_show=10,
        hide_axis=True,
        show=True,
        video_filename=None,
        colormap_name="Paired",
        framerate=15,
        n_skip=1,
        dot_scale_factor=2.5,
        alpha=1.0,
        action_description=None,
        xlim=None,
        ylim=None,
    ):
        """
        Plot actions (one action = one color) chosen in recent trajectories.

        If video_filename is given, a video file is saved showing the evolution of
        the actions taken in past trajectories.

        Parameters
        ----------
        fignum : str
            Figure name
        figsize : (float, float)
            (width, height) of the image in inches.
        n_traj_to_show : int
            Number of trajectories to visualize in each frame.
        hide_axis : bool
            If True, axes are hidden.
        show : bool
            If True, calls plt.show()
        video_filename : str or None
            If not None, save a video with given filename.
        colormap_name : str, default = 'Paired'
            Colormap name.
            See https://matplotlib.org/tutorials/colors/colormaps.html
        framerate : int, default: 15
            Video framerate.
        n_skip : int, default: 1
            Skip period: every n_skip trajectories, one trajectory is plotted.
        dot_scale_factor : double
            Scale factor for scatter plot points.
        alpha : float, default: 1.0
            The alpha blending value, between 0 (transparent) and 1 (opaque).
        action_description : list or None (optional)
            List (of strings) containing a description of each action.
            For instance, ['left', 'right', 'up', 'down'].
        xlim: list, default: None
            x plot limits, set to [0, 1] if None
        ylim: list, default: None
            y plot limits, set to [0, 1] if None
        """
        logger.info("Plotting...")

        fignum = fignum or (str(self) + "-actions")
        colormap_fn = plt.get_cmap(colormap_name)
        action_description = action_description or list(
            range(self.env.action_space.n))

        # discretizer
        try:
            discretizer = self.episode_visit_counter.state_discretizer
            epsilon = min(
                discretizer._bins[0][1] - discretizer._bins[0][0],
                discretizer._bins[1][1] - discretizer._bins[1][0],
            )
        except Exception:
            epsilon = 0.01

        # figure setup
        xlim = xlim or [0.0, 1.0]
        ylim = ylim or [0.0, 1.0]

        # indices to visualize
        if video_filename is None:
            indices = [self.memory.n_trajectories - 1]
        else:
            indices = np.arange(self.memory.n_trajectories)[::n_skip]

        # images for video
        images = []

        # for idx in indices:
        for init_idx in indices:
            idx_set = range(max(0, init_idx - n_traj_to_show + 1),
                            init_idx + 1)
            # clear before showing new trajectories
            fig = plt.figure(fignum, figsize=figsize)
            fig.clf()
            canvas = FigureCanvas(fig)
            ax = fig.gca()

            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

            if hide_axis:
                ax.set_axis_off()

            for idx in idx_set:
                traj = self.memory.trajectories[idx]

                states = np.array([traj[ii].state for ii in range(len(traj))])
                actions = np.array(
                    [traj[ii].action for ii in range(len(traj))])

                sizes = (dot_scale_factor**2) * 750 * epsilon

                for aa in range(self.env.action_space.n):
                    states_aa = states[actions == aa]
                    color = colormap_fn(aa / self.env.action_space.n)
                    ax.scatter(
                        x=states_aa[:, 0],
                        y=states_aa[:, 1],
                        color=color,
                        s=sizes,
                        alpha=alpha,
                        label=f"action = {action_description[aa]}",
                    )

            # for unique legend entries, source: https://stackoverflow.com/a/57600060
            plt.legend(
                *[
                    *zip(*{
                        l: h
                        for h, l in zip(*ax.get_legend_handles_labels())
                    }.items())
                ][::-1],
                loc="upper left",
                bbox_to_anchor=(1.00, 1.00),
            )
            plt.tight_layout()

            if video_filename is not None:
                canvas.draw()
                image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(),
                                                dtype=np.uint8)
                image_from_plot = image_from_plot.reshape(
                    fig.canvas.get_width_height()[::-1] + (3, ))
                images.append(image_from_plot)

        if video_filename is not None:
            logger.info("... writing video ...")
            video_write(video_filename, images, framerate=framerate)

        logger.info("... done!")

        if show:
            plt.show()
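A rough usage sketch for the wrapper above. The import paths and the plot limits are assumptions (not part of the example); any environment with a 2D Box observation space and a Discrete action space should work.

# Sketch only: import paths are assumed, adjust to your package layout.
from rlberry.envs.classic_control import MountainCar
from rlberry.wrappers.vis2d import Vis2dWrapper

env = Vis2dWrapper(MountainCar(), n_bins_obs=20, memory_size=200)

for _ in range(10):                       # collect a few random trajectories
    state = env.reset()
    for _ in range(100):                  # cap episode length
        action = env.action_space.sample()
        state, reward, done, info = env.step(action)
        if done:
            break

# Scatter plot of the stored trajectories; pass video_filename=... to also write a video.
# xlim/ylim are approximate MountainCar (position, velocity) ranges.
env.plot_trajectories(dot_size_means="total_visits",
                      xlim=[-1.2, 0.6],
                      ylim=[-0.07, 0.07])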
Example #12
class UCBVIAgent(IncrementalAgent):
    """
    UCBVI [1]_ with custom exploration bonus.

    Notes
    -----
    The recommended policy after all the episodes is computed without
    exploration bonuses.

    Parameters
    ----------
    env : gym.Env
        Environment with discrete states and actions.
    n_episodes : int
        Number of episodes.
    gamma : double, default: 1.0
        Discount factor in [0, 1]. If gamma is 1.0, the problem is set to
        be finite-horizon.
    horizon : int
        Horizon of the objective function. If None and gamma<1, set to
        1/(1-gamma).
    bonus_scale_factor : double, default: 1.0
        Constant by which to multiply the exploration bonus, controls
        the level of exploration.
    bonus_type : {"simplified_bernstein"}
        Type of exploration bonus. Currently, only "simplified_bernstein"
        is implemented. If `reward_free` is true, this parameter is ignored
        and the algorithm uses 1/n bonuses.
    reward_free : bool, default: False
        If true, ignores rewards and uses only 1/n bonuses.
    stage_dependent : bool, default: False
        If true, assume that transitions and rewards can change with the stage h.
    real_time_dp : bool, default: False
        If true, uses real-time dynamic programming [2]_ instead of full backward induction
        for the sampling policy.

    References
    ----------
    .. [1] Azar et al., 2017
        Minimax Regret Bounds for Reinforcement Learning
        https://arxiv.org/abs/1703.05449

    .. [2] Efroni, Yonathan, et al.
          Tight regret bounds for model-based reinforcement learning with greedy policies.
          Advances in Neural Information Processing Systems. 2019.
          https://papers.nips.cc/paper/2019/file/25caef3a545a1fff2ff4055484f0e758-Paper.pdf
    """
    name = "UCBVI"

    def __init__(self,
                 env,
                 n_episodes=1000,
                 gamma=1.0,
                 horizon=100,
                 bonus_scale_factor=1.0,
                 bonus_type="simplified_bernstein",
                 reward_free=False,
                 stage_dependent=False,
                 real_time_dp=False,
                 **kwargs):
        # init base class
        IncrementalAgent.__init__(self, env, **kwargs)

        self.n_episodes = n_episodes
        self.gamma = gamma
        self.horizon = horizon
        self.bonus_scale_factor = bonus_scale_factor
        self.bonus_type = bonus_type
        self.reward_free = reward_free
        self.stage_dependent = stage_dependent
        self.real_time_dp = real_time_dp

        # check environment
        assert isinstance(self.env.observation_space, spaces.Discrete)
        assert isinstance(self.env.action_space, spaces.Discrete)

        # other checks
        assert gamma >= 0 and gamma <= 1.0
        if self.horizon is None:
            assert gamma < 1.0, \
                "If no horizon is given, gamma must be smaller than 1."
            self.horizon = int(np.ceil(1.0 / (1.0 - gamma)))

        # maximum value
        r_range = self.env.reward_range[1] - self.env.reward_range[0]
        if r_range == np.inf or r_range == 0.0:
            logger.warning(
                "{}: Reward range is  zero or infinity. ".format(self.name) +
                "Setting it to 1.")
            r_range = 1.0

        self.v_max = np.zeros(self.horizon)
        self.v_max[-1] = r_range
        for hh in reversed(range(self.horizon - 1)):
            self.v_max[hh] = r_range + self.gamma * self.v_max[hh + 1]

        # initialize
        self.reset()

    def reset(self, **kwargs):
        H = self.horizon
        S = self.env.observation_space.n
        A = self.env.action_space.n

        if self.stage_dependent:
            shape_hsa = (H, S, A)
            shape_hsas = (H, S, A, S)
        else:
            shape_hsa = (S, A)
            shape_hsas = (S, A, S)

        # (s, a) visit counter
        self.N_sa = np.zeros(shape_hsa)
        # (s, a) bonus
        self.B_sa = np.ones(shape_hsa)

        # MDP estimator
        self.R_hat = np.zeros(shape_hsa)
        self.P_hat = np.ones(shape_hsas) * 1.0 / S

        # Value functions
        self.V = np.ones((H, S))
        self.Q = np.zeros((H, S, A))
        # for rec. policy
        self.V_policy = np.zeros((H, S))
        self.Q_policy = np.zeros((H, S, A))

        # Init V and bonus
        if not self.stage_dependent:
            self.B_sa *= self.v_max[0]
            self.V *= self.v_max[0]
        else:
            for hh in range(self.horizon):
                self.B_sa[hh, :, :] = self.v_max[hh]
                self.V[hh, :] = self.v_max[hh]

        # ep counter
        self.episode = 0

        # useful object to compute total number of visited states & entropy of visited states
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)

        # info
        self._rewards = np.zeros(self.n_episodes)

        # update name
        if self.real_time_dp:
            self.name = 'UCBVI-RTDP'

        # default writer
        self.writer = PeriodicWriter(self.name,
                                     log_every=5 * logger.getEffectiveLevel())

    def policy(self, state, hh=0, **kwargs):
        """ Recommended policy. """
        assert self.Q_policy is not None
        return self.Q_policy[hh, state, :].argmax()

    def _get_action(self, state, hh=0):
        """ Sampling policy. """
        if not self.real_time_dp:
            assert self.Q is not None
            return self.Q[hh, state, :].argmax()
        else:
            if self.stage_dependent:
                update_fn = update_value_and_get_action_sd
            else:
                update_fn = update_value_and_get_action
            return update_fn(
                state,
                hh,
                self.V,
                self.R_hat,
                self.P_hat,
                self.B_sa,
                self.gamma,
                self.v_max,
            )

    def _compute_bonus(self, n, hh):
        # reward-free
        if self.reward_free:
            bonus = 1.0 / n
            return bonus

        # not reward-free
        if self.bonus_type == "simplified_bernstein":
            bonus = self.bonus_scale_factor * np.sqrt(
                1.0 / n) + self.v_max[hh] / n
            bonus = min(bonus, self.v_max[hh])
            return bonus
        else:
            raise ValueError("Error: bonus type {} not implemented".format(
                self.bonus_type))

    def _update(self, state, action, next_state, reward, hh):
        if self.stage_dependent:
            self.N_sa[hh, state, action] += 1

            nn = self.N_sa[hh, state, action]
            prev_r = self.R_hat[hh, state, action]
            prev_p = self.P_hat[hh, state, action, :]

            self.R_hat[hh, state,
                       action] = (1.0 - 1.0 / nn) * prev_r + reward * 1.0 / nn

            self.P_hat[hh, state, action, :] = (1.0 - 1.0 / nn) * prev_p
            self.P_hat[hh, state, action, next_state] += 1.0 / nn

            self.B_sa[hh, state, action] = self._compute_bonus(nn, hh)

        else:
            self.N_sa[state, action] += 1

            nn = self.N_sa[state, action]
            prev_r = self.R_hat[state, action]
            prev_p = self.P_hat[state, action, :]

            self.R_hat[state,
                       action] = (1.0 - 1.0 / nn) * prev_r + reward * 1.0 / nn

            self.P_hat[state, action, :] = (1.0 - 1.0 / nn) * prev_p
            self.P_hat[state, action, next_state] += 1.0 / nn

            self.B_sa[state, action] = self._compute_bonus(nn, 0)

    def _run_episode(self):
        # interact for H steps
        episode_rewards = 0
        state = self.env.reset()
        for hh in range(self.horizon):
            action = self._get_action(state, hh)
            next_state, reward, done, _ = self.env.step(action)
            episode_rewards += reward  # used for logging only

            self.counter.update(state, action)

            if self.reward_free:
                reward = 0.0  # set to zero before update if reward_free

            self._update(state, action, next_state, reward, hh)

            state = next_state
            if done:
                break

        # run backward induction
        if not self.real_time_dp:
            if self.stage_dependent:
                backward_induction_sd(self.Q, self.V, self.R_hat + self.B_sa,
                                      self.P_hat, self.gamma, self.v_max[0])
            else:
                backward_induction_in_place(self.Q, self.V,
                                            self.R_hat + self.B_sa, self.P_hat,
                                            self.horizon, self.gamma,
                                            self.v_max[0])

        # update info
        ep = self.episode
        self._rewards[ep] = episode_rewards
        self.episode += 1

        # writer
        if self.writer is not None:
            self.writer.add_scalar("ep reward", episode_rewards, self.episode)
            self.writer.add_scalar("n_visited_states",
                                   self.counter.get_n_visited_states(),
                                   self.episode)

        # return sum of rewards collected in the episode
        return episode_rewards

    def partial_fit(self, fraction, **kwargs):
        assert 0.0 < fraction <= 1.0
        n_episodes_to_run = int(np.ceil(fraction * self.n_episodes))
        count = 0
        while count < n_episodes_to_run and self.episode < self.n_episodes:
            self._run_episode()
            count += 1

        # compute Q function for the recommended policy
        if self.stage_dependent:
            backward_induction_sd(self.Q_policy, self.V_policy, self.R_hat,
                                  self.P_hat, self.gamma, self.v_max[0])
        else:
            backward_induction_in_place(self.Q_policy, self.V_policy,
                                        self.R_hat, self.P_hat, self.horizon,
                                        self.gamma, self.v_max[0])

        info = {
            "n_episodes": self.episode,
            "episode_rewards": self._rewards[:self.episode]
        }
        return info
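A minimal training sketch for the UCBVIAgent above. The import paths are assumptions; the constructor, partial_fit, and policy calls follow the signatures shown in the class.

# Sketch only: import paths are assumed, adjust to your package layout.
from rlberry.envs.finite import GridWorld
from rlberry.agents.ucbvi import UCBVIAgent

env = GridWorld()
agent = UCBVIAgent(env,
                   n_episodes=500,
                   horizon=20,
                   bonus_scale_factor=0.1,
                   stage_dependent=False)

# IncrementalAgent interface: run a fraction of the n_episodes budget.
info = agent.partial_fit(fraction=1.0)
print(info["n_episodes"], info["episode_rewards"][-5:])

# Recommended (bonus-free) policy at the first stage.
state = env.reset()
action = agent.policy(state, hh=0)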
Example #13
class OptQLAgent(AgentWithSimplePolicy):
    """
    Optimistic Q-Learning [1]_ with custom exploration bonuses.

    Parameters
    ----------
    env : gym.Env
        Environment with discrete states and actions.
    gamma : double, default: 1.0
        Discount factor in [0, 1].
    horizon : int
        Horizon of the objective function.
    bonus_scale_factor : double, default: 1.0
        Constant by which to multiply the exploration bonus, controls
        the level of exploration.
    bonus_type : {"simplified_bernstein"}
        Type of exploration bonus. Currently, only "simplified_bernstein"
        is implemented.
    add_bonus_after_update : bool, default: False
        If True, add bonus to the Q function after performing the update,
        instead of adding it to the update target.

    References
    ----------
    .. [1] Jin et al., 2018
           Is Q-Learning Provably Efficient?
           https://arxiv.org/abs/1807.03765
    """

    name = "OptQL"

    def __init__(self,
                 env,
                 gamma=1.0,
                 horizon=100,
                 bonus_scale_factor=1.0,
                 bonus_type="simplified_bernstein",
                 add_bonus_after_update=False,
                 **kwargs):
        # init base class
        AgentWithSimplePolicy.__init__(self, env, **kwargs)

        self.gamma = gamma
        self.horizon = horizon
        self.bonus_scale_factor = bonus_scale_factor
        self.bonus_type = bonus_type
        self.add_bonus_after_update = add_bonus_after_update

        # check environment
        assert isinstance(self.env.observation_space, spaces.Discrete)
        assert isinstance(self.env.action_space, spaces.Discrete)

        # maximum value
        r_range = self.env.reward_range[1] - self.env.reward_range[0]
        if r_range == np.inf or r_range == 0.0:
            logger.warning(
                "{}: Reward range is  zero or infinity. ".format(self.name) +
                "Setting it to 1.")
            r_range = 1.0

        self.v_max = np.zeros(self.horizon)
        self.v_max[-1] = r_range
        for hh in reversed(range(self.horizon - 1)):
            self.v_max[hh] = r_range + self.gamma * self.v_max[hh + 1]

        # initialize
        self.reset()

    def reset(self, **kwargs):
        H = self.horizon
        S = self.env.observation_space.n
        A = self.env.action_space.n

        # (s, a) visit counter
        self.N_sa = np.zeros((H, S, A))

        # Value functions
        self.V = np.ones((H + 1, S))
        self.V[H, :] = 0
        self.Q = np.ones((H, S, A))
        self.Q_bar = np.ones((H, S, A))
        for hh in range(self.horizon):
            self.V[hh, :] *= self.horizon - hh
            self.Q[hh, :, :] *= self.horizon - hh
            self.Q_bar[hh, :, :] *= self.horizon - hh

        if self.add_bonus_after_update:
            self.Q *= 0.0

        # ep counter
        self.episode = 0

        # useful object to compute total number of visited states & entropy of visited states
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)

    def policy(self, observation):
        """Recommended policy."""
        state = observation
        return self.Q_bar[0, state, :].argmax()

    def _get_action(self, state, hh=0):
        """Sampling policy."""
        return self.Q_bar[hh, state, :].argmax()

    def _compute_bonus(self, n, hh):
        if self.bonus_type == "simplified_bernstein":
            bonus = self.bonus_scale_factor * np.sqrt(
                1.0 / n) + self.v_max[hh] / n
            bonus = min(bonus, self.v_max[hh])
            return bonus
        else:
            raise ValueError("Error: bonus type {} not implemented".format(
                self.bonus_type))

    def _update(self, state, action, next_state, reward, hh):
        self.N_sa[hh, state, action] += 1
        nn = self.N_sa[hh, state, action]

        # learning rate
        alpha = (self.horizon + 1.0) / (self.horizon + nn)
        bonus = self._compute_bonus(nn, hh)

        # bonus in the update
        if not self.add_bonus_after_update:
            target = reward + bonus + self.gamma * self.V[hh + 1, next_state]
            self.Q[hh, state,
                   action] = (1 - alpha) * self.Q[hh, state,
                                                  action] + alpha * target
            self.V[hh, state] = min(self.v_max[hh], self.Q[hh, state, :].max())
            self.Q_bar[hh, state, action] = self.Q[hh, state, action]
        # bonus outside the update
        else:
            target = reward + self.gamma * self.V[hh + 1,
                                                  next_state]  # bonus not here
            self.Q[hh, state,
                   action] = (1 - alpha) * self.Q[hh, state,
                                                  action] + alpha * target
            self.Q_bar[hh, state, action] = (self.Q[hh, state, action] + bonus
                                             )  # bonus here
            self.V[hh, state] = min(self.v_max[hh], self.Q_bar[hh,
                                                               state, :].max())

    def _run_episode(self):
        # interact for H steps
        episode_rewards = 0
        state = self.env.reset()
        for hh in range(self.horizon):
            action = self._get_action(state, hh)
            next_state, reward, done, _ = self.env.step(action)
            episode_rewards += reward  # used for logging only

            self.counter.update(state, action)

            self._update(state, action, next_state, reward, hh)

            state = next_state
            if done:
                break

        # update info
        self.episode += 1

        # writer
        if self.writer is not None:
            self.writer.add_scalar("episode_rewards", episode_rewards,
                                   self.episode)
            self.writer.add_scalar("n_visited_states",
                                   self.counter.get_n_visited_states(),
                                   self.episode)

        # return sum of rewards collected in the episode
        return episode_rewards

    def fit(self, budget: int, **kwargs):
        del kwargs
        n_episodes_to_run = budget
        count = 0
        while count < n_episodes_to_run:
            self._run_episode()
            count += 1
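A minimal training sketch for the OptQLAgent above (import paths are assumptions; the calls mirror the fit/policy interface shown in the class).

# Sketch only: import paths are assumed, adjust to your package layout.
from rlberry.envs.finite import GridWorld
from rlberry.agents.optql import OptQLAgent

env = GridWorld()
agent = OptQLAgent(env,
                   horizon=20,
                   bonus_scale_factor=0.1,
                   add_bonus_after_update=False)

agent.fit(budget=300)          # budget = number of episodes to run

state = env.reset()
action = agent.policy(state)   # greedy w.r.t. Q_bar at stage 0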
Example #14
def test_discrete_env(rate_power):
    env = GridWorld()
    counter = DiscreteCounter(env.observation_space,
                              env.action_space,
                              rate_power=rate_power)

    for N in range(10, 20):
        assert counter.get_n_visited_states() == 0
        assert counter.get_entropy() == 0.0

        for ss in range(env.observation_space.n):
            for aa in range(env.action_space.n):
                for _ in range(N):
                    ns, rr, _, _ = env.sample(ss, aa)
                    counter.update(ss, aa, ns, rr)
                assert counter.N_sa[ss, aa] == N
                assert counter.count(ss, aa) == N
                if rate_power == pytest.approx(1):
                    assert np.allclose(counter.measure(ss, aa), 1.0 / N)
                elif rate_power == pytest.approx(0.5):
                    assert np.allclose(counter.measure(ss, aa),
                                       np.sqrt(1.0 / N))

        assert counter.get_n_visited_states() == env.observation_space.n
        assert np.allclose(counter.get_entropy(),
                           np.log2(env.observation_space.n))

        counter.reset()
Example #15
class RLSVIAgent(AgentWithSimplePolicy):
    """
    RLSVI algorithm from [1,2] with Gaussian noise.

    Notes
    -----
    The recommended policy after all the episodes is computed with the empirical
    MDP.
    The std of the noise is of the form
    scale/sqrt(n) + V_max/n,
    as for simplified Bernstein bonuses.

    Parameters
    ----------
    env : gym.Env
        Environment with discrete states and actions.
    gamma : double, default: 1.0
        Discount factor in [0, 1]. If gamma is 1.0, the problem is set to
        be finite-horizon.
    horizon : int
        Horizon of the objective function. If None and gamma<1, set to
        1/(1-gamma).
    scale_std_noise : double, default: 1.0
        Scale of the std of the noise. At step h the std is
        scale_std_noise/sqrt(n) + (H-h+1)/n.
    reward_free : bool, default: False
        If true, ignores rewards.
    stage_dependent : bool, default: False
        If true, assume that transitions and rewards can change with the stage h.

    References
    ----------
    .. [1] Osband et al., 2014
        Generalization and Exploration via Randomized Value Functions
        https://arxiv.org/abs/1402.0635

    .. [2] Russo, 2019
        Worst-Case Regret Bounds for Exploration via Randomized Value Functions
        https://arxiv.org/abs/1906.02870

    """

    name = "RLSVI"

    def __init__(self,
                 env,
                 gamma=1.0,
                 horizon=100,
                 scale_std_noise=1.0,
                 reward_free=False,
                 stage_dependent=False,
                 **kwargs):
        # init base class
        AgentWithSimplePolicy.__init__(self, env, **kwargs)

        self.gamma = gamma
        self.horizon = horizon
        self.scale_std_noise = scale_std_noise
        self.reward_free = reward_free
        self.stage_dependent = stage_dependent

        # check environment
        assert isinstance(self.env.observation_space, spaces.Discrete)
        assert isinstance(self.env.action_space, spaces.Discrete)

        # other checks
        assert gamma >= 0 and gamma <= 1.0
        if self.horizon is None:
            assert gamma < 1.0, "If no horizon is given, gamma must be smaller than 1."
            self.horizon = int(np.ceil(1.0 / (1.0 - gamma)))

        # maximum value
        r_range = self.env.reward_range[1] - self.env.reward_range[0]
        if r_range == np.inf or r_range == 0.0:
            logger.warning(
                "{}: Reward range is  zero or infinity. ".format(self.name) +
                "Setting it to 1.")
            r_range = 1.0

        self.v_max = np.zeros(self.horizon)
        self.v_max[-1] = r_range
        for hh in reversed(range(self.horizon - 1)):
            self.v_max[hh] = r_range + self.gamma * self.v_max[hh + 1]

        # initialize
        self.reset()

    def reset(self, **kwargs):
        H = self.horizon
        S = self.env.observation_space.n
        A = self.env.action_space.n

        if self.stage_dependent:
            shape_hsa = (H, S, A)
            shape_hsas = (H, S, A, S)
        else:
            shape_hsa = (S, A)
            shape_hsas = (S, A, S)

        # stds prior
        self.std1_sa = self.scale_std_noise * np.ones((H, S, A))
        self.std2_sa = np.ones((H, S, A))
        # visit counter
        self.N_sa = np.ones(shape_hsa)

        # MDP estimator
        self.R_hat = np.zeros(shape_hsa)
        self.P_hat = np.ones(shape_hsas) * 1.0 / S

        # Value functions
        self.V = np.zeros((H, S))
        self.Q = np.zeros((H, S, A))
        # for rec. policy
        self.V_policy = np.zeros((H, S))
        self.Q_policy = np.zeros((H, S, A))

        # Init V and variances
        for hh in range(self.horizon):
            self.std2_sa[hh, :, :] *= self.v_max[hh]

        # ep counter
        self.episode = 0

        # useful object to compute total number of visited states & entropy of visited states
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)

    def policy(self, observation):
        state = observation
        assert self.Q_policy is not None
        return self.Q_policy[0, state, :].argmax()

    def _get_action(self, state, hh=0):
        """Sampling policy."""
        assert self.Q is not None
        return self.Q[hh, state, :].argmax()

    def _update(self, state, action, next_state, reward, hh):
        if self.stage_dependent:
            self.N_sa[hh, state, action] += 1

            nn = self.N_sa[hh, state, action]
            prev_r = self.R_hat[hh, state, action]
            prev_p = self.P_hat[hh, state, action, :]

            self.R_hat[hh, state,
                       action] = (1.0 - 1.0 / nn) * prev_r + reward * 1.0 / nn

            self.P_hat[hh, state, action, :] = (1.0 - 1.0 / nn) * prev_p
            self.P_hat[hh, state, action, next_state] += 1.0 / nn

        else:
            self.N_sa[state, action] += 1

            nn = self.N_sa[state, action]
            prev_r = self.R_hat[state, action]
            prev_p = self.P_hat[state, action, :]

            self.R_hat[state,
                       action] = (1.0 - 1.0 / nn) * prev_r + reward * 1.0 / nn

            self.P_hat[state, action, :] = (1.0 - 1.0 / nn) * prev_p
            self.P_hat[state, action, next_state] += 1.0 / nn

    def _run_episode(self):
        # interact for H steps
        episode_rewards = 0
        # stds scale/sqrt(n)+(H-h+1)/n
        std_sa = self.std1_sa / np.sqrt(self.N_sa) + self.std2_sa / self.N_sa
        noise_sa = self.rng.normal(0.0, std_sa)  # zero-mean noise; R_hat is added below
        # run backward noisy induction
        if self.stage_dependent:
            backward_induction_sd(
                self.Q,
                self.V,
                self.R_hat + noise_sa,
                self.P_hat,
                self.gamma,
                self.v_max[0],
            )
        else:
            backward_induction_reward_sd(
                self.Q,
                self.V,
                self.R_hat + noise_sa,
                self.P_hat,
                self.gamma,
                self.v_max[0],
            )

        state = self.env.reset()
        for hh in range(self.horizon):
            action = self._get_action(state, hh)
            next_state, reward, done, _ = self.env.step(action)
            episode_rewards += reward  # used for logging only

            self.counter.update(state, action)

            if self.reward_free:
                reward = 0.0  # set to zero before update if reward_free

            self._update(state, action, next_state, reward, hh)

            state = next_state
            if done:
                break

        # update info
        self.episode += 1

        # writer
        if self.writer is not None:
            self.writer.add_scalar("episode_rewards", episode_rewards,
                                   self.episode)
            self.writer.add_scalar("n_visited_states",
                                   self.counter.get_n_visited_states(),
                                   self.episode)

        # return sum of rewards collected in the episode
        return episode_rewards

    def fit(self, budget: int, **kwargs):
        del kwargs
        n_episodes_to_run = budget
        count = 0
        while count < n_episodes_to_run:
            self._run_episode()
            count += 1

        # compute Q function for the recommended policy
        if self.stage_dependent:
            backward_induction_sd(
                self.Q_policy,
                self.V_policy,
                self.R_hat,
                self.P_hat,
                self.gamma,
                self.v_max[0],
            )
        else:
            backward_induction_in_place(
                self.Q_policy,
                self.V_policy,
                self.R_hat,
                self.P_hat,
                self.horizon,
                self.gamma,
                self.v_max[0],
            )
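A minimal training sketch for the RLSVIAgent above (import paths are assumptions; the noise scale follows the scale/sqrt(n) + v_max/n form described in the docstring).

# Sketch only: import paths are assumed, adjust to your package layout.
from rlberry.envs.finite import GridWorld
from rlberry.agents.rlsvi import RLSVIAgent

env = GridWorld()
agent = RLSVIAgent(env, horizon=20, scale_std_noise=0.5)

agent.fit(budget=300)          # one noisy backward induction + one episode per iteration

state = env.reset()
action = agent.policy(state)   # greedy w.r.t. Q_policy computed from the empirical MDP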
Example #16
class PSRLAgent(AgentWithSimplePolicy):
    """
    PSRL algorithm from [1] with beta prior for the "Bernoullized" rewards
    (instead of Gaussian-gamma prior).

    Notes
    -----
    The recommended policy after all the episodes is computed without
    exploration bonuses.

    Parameters
    ----------
    env : gym.Env
        Environment with discrete states and actions.
    gamma : double, default: 1.0
        Discount factor in [0, 1]. If gamma is 1.0, the problem is set to
        be finite-horizon.
    horizon : int
        Horizon of the objective function. If None and gamma<1, set to
        1/(1-gamma).
    scale_prior_reward : double, default: 1.0
        Scale of the Beta (uniform) prior,
        i.e. the prior is Beta(scale_prior_reward*(1,1)).
    scale_prior_transition : double, default: 1/number of states
        Scale of the (uniform) Dirichlet prior,
        i.e. the prior is Dirichlet(scale_prior_transition*(1,...,1)).
    bernoullized_reward : bool, default: True
        If true, the rewards are Bernoullized.
    reward_free : bool, default: False
        If true, ignores rewards and uses only 1/n bonuses.
    stage_dependent : bool, default: False
        If true, assume that transitions and rewards can change with the stage h.

    References
    ----------
    .. [1] Osband et al., 2013
        (More) Efficient Reinforcement Learning via Posterior Sampling
        https://arxiv.org/abs/1306.0940

    """

    name = "PSRL"

    def __init__(self,
                 env,
                 gamma=1.0,
                 horizon=100,
                 scale_prior_reward=1,
                 scale_prior_transition=None,
                 bernoullized_reward=True,
                 reward_free=False,
                 stage_dependent=False,
                 **kwargs):
        # init base class
        AgentWithSimplePolicy.__init__(self, env, **kwargs)

        self.gamma = gamma
        self.horizon = horizon
        self.scale_prior_reward = scale_prior_reward
        self.scale_prior_transition = scale_prior_transition
        if scale_prior_transition is None:
            self.scale_prior_transition = 1.0 / self.env.observation_space.n
        self.bernoullized_reward = bernoullized_reward
        self.reward_free = reward_free
        self.stage_dependent = stage_dependent

        # check environment
        assert isinstance(self.env.observation_space, spaces.Discrete)
        assert isinstance(self.env.action_space, spaces.Discrete)

        # other checks
        assert gamma >= 0 and gamma <= 1.0
        if self.horizon is None:
            assert gamma < 1.0, "If no horizon is given, gamma must be smaller than 1."
            self.horizon = int(np.ceil(1.0 / (1.0 - gamma)))

        # maximum value
        r_range = self.env.reward_range[1] - self.env.reward_range[0]
        if r_range == np.inf or r_range == 0.0:
            logger.warning(
                "{}: Reward range is  zero or infinity. ".format(self.name) +
                "Setting it to 1.")
            r_range = 1.0

        self.v_max = np.zeros(self.horizon)
        self.v_max[-1] = r_range
        for hh in reversed(range(self.horizon - 1)):
            self.v_max[hh] = r_range + self.gamma * self.v_max[hh + 1]

        # initialize
        self.reset()

    def reset(self, **kwargs):
        H = self.horizon
        S = self.env.observation_space.n
        A = self.env.action_space.n

        if self.stage_dependent:
            shape_hsa = (H, S, A)
            shape_hsas = (H, S, A, S)
        else:
            shape_hsa = (S, A)
            shape_hsas = (S, A, S)

        # Prior transitions
        self.N_sas = self.scale_prior_transition * np.ones(shape_hsas)

        # Prior rewards
        self.M_sa = self.scale_prior_reward * np.ones(shape_hsa + (2, ))

        # Value functions
        self.V = np.zeros((H, S))
        self.Q = np.zeros((H, S, A))
        # for rec. policy
        self.V_policy = np.zeros((H, S))
        self.Q_policy = np.zeros((H, S, A))

        # ep counter
        self.episode = 0

        # useful object to compute total number of visited states & entropy of visited states
        self.counter = DiscreteCounter(self.env.observation_space,
                                       self.env.action_space)

    def policy(self, observation):
        state = observation
        assert self.Q_policy is not None
        return self.Q_policy[0, state, :].argmax()

    def _get_action(self, state, hh=0):
        """Sampling policy."""
        assert self.Q is not None
        return self.Q[hh, state, :].argmax()

    def _update(self, state, action, next_state, reward, hh):
        bern_reward = reward
        if self.bernoullized_reward:
            bern_reward = self.rng.binomial(1, reward)
        # update posterior
        if self.stage_dependent:
            self.N_sas[hh, state, action, next_state] += 1
            self.M_sa[hh, state, action, 0] += bern_reward
            self.M_sa[hh, state, action, 1] += 1 - bern_reward

        else:
            self.N_sas[state, action, next_state] += 1
            self.M_sa[state, action, 0] += bern_reward
            self.M_sa[state, action, 1] += 1 - bern_reward

    def _run_episode(self):
        # sample reward and transitions from posterior
        self.R_sample = self.rng.beta(self.M_sa[..., 0], self.M_sa[..., 1])
        self.P_sample = self.rng.gamma(self.N_sas)
        self.P_sample = self.P_sample / self.P_sample.sum(-1, keepdims=True)
        # run backward induction
        if self.stage_dependent:
            backward_induction_sd(self.Q, self.V, self.R_sample, self.P_sample,
                                  self.gamma, self.v_max[0])
        else:
            backward_induction_in_place(
                self.Q,
                self.V,
                self.R_sample,
                self.P_sample,
                self.horizon,
                self.gamma,
                self.v_max[0],
            )
        # interact for H steps
        episode_rewards = 0
        state = self.env.reset()
        for hh in range(self.horizon):
            action = self._get_action(state, hh)
            next_state, reward, done, _ = self.env.step(action)
            episode_rewards += reward  # used for logging only

            self.counter.update(state, action)

            if self.reward_free:
                reward = 0.0  # set to zero before update if reward_free

            self._update(state, action, next_state, reward, hh)

            state = next_state
            if done:
                break

        # update info
        self.episode += 1

        # writer
        if self.writer is not None:
            self.writer.add_scalar("episode_rewards", episode_rewards,
                                   self.episode)
            self.writer.add_scalar("n_visited_states",
                                   self.counter.get_n_visited_states(),
                                   self.episode)

        # return sum of rewards collected in the episode
        return episode_rewards

    def fit(self, budget: int, **kwargs):
        del kwargs
        n_episodes_to_run = budget
        count = 0
        while count < n_episodes_to_run:
            self._run_episode()
            count += 1

        # compute Q function for the recommended policy
        R_hat = self.M_sa[..., 0] / (self.M_sa[..., 0] + self.M_sa[..., 1])
        P_hat = self.N_sas / self.N_sas.sum(-1, keepdims=True)
        if self.stage_dependent:
            backward_induction_sd(self.Q_policy, self.V_policy, R_hat, P_hat,
                                  self.gamma, self.v_max[0])
        else:
            backward_induction_in_place(
                self.Q_policy,
                self.V_policy,
                R_hat,
                P_hat,
                self.horizon,
                self.gamma,
                self.v_max[0],
            )
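A minimal training sketch for the PSRLAgent above (import paths are assumptions; note that the Bernoullized-reward update rng.binomial(1, reward) assumes rewards in [0, 1]).

# Sketch only: import paths are assumed, adjust to your package layout.
from rlberry.envs.finite import GridWorld
from rlberry.agents.psrl import PSRLAgent

env = GridWorld()
agent = PSRLAgent(env, horizon=20, bernoullized_reward=True)

agent.fit(budget=300)          # one posterior sample + one episode per iteration

state = env.reset()
action = agent.policy(state)   # greedy w.r.t. the posterior-mean Q_policy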
Example #17
def uncertainty_estimator_fn(observation_space, action_space):
    counter = DiscreteCounter(observation_space,
                              action_space,
                              n_bins_obs=20)
    return counter
Example #18
def uncertainty_est_fn(observation_space, action_space):
    return DiscreteCounter(observation_space, action_space)
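Both factories simply build a DiscreteCounter from the environment's spaces. A small sketch of how the returned counter can be exercised directly, mirroring the test examples above (the GridWorld import path is an assumption):

# Sketch only: the GridWorld import path is assumed.
from rlberry.envs.finite import GridWorld

env = GridWorld()
counter = uncertainty_estimator_fn(env.observation_space, env.action_space)

state = env.reset()
for _ in range(100):
    action = env.action_space.sample()
    next_state, reward, done, _ = env.step(action)
    counter.update(state, action, next_state, reward)
    state = env.reset() if done else next_state

print(counter.get_n_visited_states(), counter.get_entropy())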