def _play(self, play_context: core.PlayContext,
          callbacks: Union[List[core.AgentCallback], core.AgentCallback, None],
          default_plots: Optional[bool]):
    """Plays episodes with the current policy as configured by play_context.

    The rewards received are written back into play_context.

    Args:
        play_context: specifies the number of episodes to play.
        callbacks: a single AgentCallback, a list of them, or None; invoked
            during the play of the episodes.
        default_plots: if set, adds a set of default plot callbacks
            (plot.Steps, plot.Rewards). If None, the defaults are only added
            when the callbacks list is empty.

    Returns:
        play_context containing the actions taken and the rewards received.
    """
    assert play_context, "play_context not set."
    # Normalize the callbacks argument to a list.
    if callbacks is None:
        callback_list = []
    elif isinstance(callbacks, list):
        callback_list = callbacks
    else:
        assert isinstance(callbacks, core.AgentCallback), "callback not an AgentCallback or a list thereof."
        callback_list = [callbacks]
    prepared = self._prepare_callbacks(callback_list, default_plots,
                                       [plot.Steps(), plot.Rewards()])
    self._backend_agent.play(play_context=play_context, callbacks=prepared)
    return play_context
def train(self, train_context: core.TrainContext,
          callbacks: Union[List[core.AgentCallback], core.AgentCallback, None],
          default_plots: Optional[bool]):
    """Trains a new model using the gym environment passed during instantiation.

    Args:
        train_context: training configuration to be used
            (num_iterations, num_episodes_per_iteration, ...).
        callbacks: a single AgentCallback, a list of them, or None; invoked
            during the training and evaluation.
        default_plots: if set, adds a set of default plot callbacks
            (plot.Loss, plot.Steps, plot.Rewards). If None, the defaults are
            only added when the callbacks list is empty.
    """
    assert train_context, "train_context not set."
    # Normalize the callbacks argument to a list.
    if callbacks is None:
        callbacks = []
    if not isinstance(callbacks, list):
        # Fixed message ("an", not "a") for consistency with _play().
        assert isinstance(
            callbacks, core.AgentCallback
        ), "callback not an AgentCallback or a list thereof."
        callbacks = [callbacks]
    callbacks = self._prepare_callbacks(
        callbacks, default_plots, [plot.Loss(), plot.Steps(), plot.Rewards()])
    self._backend_agent.train(train_context=train_context, callbacks=callbacks)
def play(self,
         callbacks: Union[List[core.AgentCallback], core.AgentCallback, None] = None,
         num_episodes: int = 1,
         max_steps_per_episode: int = 1000,
         play_context: Optional[core.PlayContext] = None,
         default_plots: Optional[bool] = None):
    """Plays num_episodes with the current policy.

    Args:
        callbacks: a single AgentCallback, a list of them, or None; invoked
            during each episode play.
        num_episodes: number of episodes to play.
        max_steps_per_episode: max steps per episode.
        play_context: play configuration to be used. If set, overrides all
            other play context arguments (num_episodes, max_steps_per_episode).
        default_plots: if set, adds a set of default plot callbacks
            (plot.Steps, plot.Rewards).

    Returns:
        play_context containing the actions taken and the rewards received.
    """
    assert self._backend_agent._agent_context._is_policy_trained, "No trained policy available. Call train() first."
    if play_context is None:
        # Build a context from the individual arguments.
        play_context = core.PlayContext()
        play_context.max_steps_per_episode = max_steps_per_episode
        play_context.num_episodes = num_episodes
    callbacks = self._to_callback_list(callbacks=callbacks)
    callbacks = self._add_plot_callbacks(callbacks, default_plots,
                                         [plot.Steps(), plot.Rewards()])
    self._backend_agent.play(play_context=play_context, callbacks=callbacks)
    return play_context
def test_train_multiple_subplots(self):
    """Smoke test: a single training iteration with several plot callbacks."""
    ppo = agents.PpoAgent("CartPole-v0")
    training_callbacks = [
        duration._SingleIteration(),
        plot.State(),
        plot.Rewards(),
        plot.Loss(),
        plot.Steps(),
    ]
    ppo.train(training_callbacks)
def test_train_plotsteps(self):
    """Smoke test: a single training iteration with the Steps plot callback."""
    ppo = agents.PpoAgent("CartPole-v0")
    ppo.train([duration._SingleIteration(), plot.Steps()])