Ejemplo n.º 1
0
 def test_default_plots_True_plotcallback(self):
     """Both plots must be returned when the flag passed is ``True``.

     A Rewards plot is handed over as the explicit callback and a Loss
     plot as the candidate passed in the third argument; the result of
     ``_add_plot_callbacks`` has to contain each of them.
     """
     ppo_agent = agents.PpoAgent("CartPole-v0")
     loss_plot = plot.Loss()
     rewards_plot = plot.Rewards()
     resulting = ppo_agent._add_plot_callbacks([rewards_plot], True,
                                               [loss_plot])
     # Checked in the same order as before: Loss first, then Rewards.
     for expected in (loss_plot, rewards_plot):
         assert expected in resulting
Ejemplo n.º 2
0
 def test_default_plots_None_plotcallback(self):
     """With ``None`` and an explicit plot callback, the candidate is dropped.

     A Rewards plot is passed as the explicit callback and a Loss plot as
     the candidate in the third argument of ``_prepare_callbacks``. Only
     the Rewards plot may show up in the result.
     """
     ppo_agent = agents.PpoAgent("CartPole-v0")
     loss_plot = plot.Loss()
     rewards_plot = plot.Rewards()
     resulting = ppo_agent._prepare_callbacks([rewards_plot], None,
                                              [loss_plot])
     assert loss_plot not in resulting
     assert rewards_plot in resulting
Ejemplo n.º 3
0
    def train(self, train_context: core.TrainContext,
              callbacks: Union[List[core.AgentCallback], core.AgentCallback,
                               None], default_plots: Optional[bool]):
        """Train a new model on the gym environment given at instantiation.

        The ``callbacks`` argument is normalised to a list, handed to
        ``_prepare_callbacks`` together with ``default_plots`` and the
        candidate plots (Loss, Steps, Rewards), and the outcome is
        forwarded to the backend agent's ``train``.

        Args:
            train_context: training configuration to be used
                (num_iterations, num_episodes_per_iteration, ...).
            callbacks: a single callback, a list of callbacks, or None;
                called during training and evaluation.
            default_plots: if set, adds a set of default callbacks
                (plot.State, plot.Rewards, plot.Loss, ...). If None, default
                callbacks are only added if the callbacks list is empty.

        Raises:
            AssertionError: if ``train_context`` is falsy, or if
                ``callbacks`` is neither None, a list, nor an
                ``AgentCallback``.
        """
        assert train_context, "train_context not set."

        # Bring the three accepted shapes of `callbacks` down to one list.
        if callbacks is None:
            callback_list = []
        elif isinstance(callbacks, list):
            callback_list = callbacks
        else:
            assert isinstance(
                callbacks, core.AgentCallback
            ), "callback not a AgentCallback or a list thereof."
            callback_list = [callbacks]

        prepared = self._prepare_callbacks(
            callback_list,
            default_plots,
            [plot.Loss(), plot.Steps(), plot.Rewards()],
        )
        self._backend_agent.train(train_context=train_context,
                                  callbacks=prepared)
Ejemplo n.º 4
0
 def test_train_multiple_subplots(self):
     """Smoke test: training with several plot callbacks at once.

     Passes a single-iteration duration plus State, Rewards, Loss and
     Steps plots to ``train``; the test has no assertions and only
     requires the call to complete.
     """
     ppo_agent = agents.PpoAgent("CartPole-v0")
     training_callbacks = [
         duration._SingleIteration(),
         plot.State(),
         plot.Rewards(),
         plot.Loss(),
         plot.Steps(),
     ]
     ppo_agent.train(training_callbacks)
Ejemplo n.º 5
0
 def test_train_plotloss(self):
     """Smoke test: training with a single-iteration duration and a Loss plot.

     The test has no assertions and only requires ``train`` to complete.
     """
     ppo_agent = agents.PpoAgent("CartPole-v0")
     training_callbacks = [duration._SingleIteration(), plot.Loss()]
     ppo_agent.train(training_callbacks)
Ejemplo n.º 6
0
 def test_default_plots_None_nocallback(self):
     """With ``None`` and no explicit callbacks, the candidate plot is kept.

     An empty callback list and a Loss plot as the candidate are passed to
     ``_add_plot_callbacks``; the Loss plot has to be part of the result.
     """
     ppo_agent = agents.PpoAgent("CartPole-v0")
     loss_plot = plot.Loss()
     resulting = ppo_agent._add_plot_callbacks([], None, [loss_plot])
     assert loss_plot in resulting
Ejemplo n.º 7
0
 def test_default_plots_None_durationcallback(self):
     """With ``None`` and only a duration callback, the candidate plot is kept.

     The explicit callback list holds just ``duration.Fast()``; the Loss
     plot passed as the candidate to ``_prepare_callbacks`` has to be part
     of the result.
     """
     ppo_agent = agents.PpoAgent("CartPole-v0")
     loss_plot = plot.Loss()
     explicit_callbacks = [duration.Fast()]
     resulting = ppo_agent._prepare_callbacks(explicit_callbacks, None,
                                              [loss_plot])
     assert loss_plot in resulting
Ejemplo n.º 8
0
 def test_default_plots_False_nocallback(self):
     """With ``False``, the candidate plot is not added even without callbacks.

     An empty callback list and a Loss plot as the candidate are passed to
     ``_prepare_callbacks``; the Loss plot must be absent from the result.
     """
     ppo_agent = agents.PpoAgent("CartPole-v0")
     loss_plot = plot.Loss()
     resulting = ppo_agent._prepare_callbacks([], False, [loss_plot])
     assert loss_plot not in resulting