def train(self,
          callbacks: Union[List[core.AgentCallback], core.AgentCallback, None] = None,
          num_iterations: int = 10,
          max_steps_per_episode: int = 1000,
          num_episodes_per_eval: int = 10,
          train_context: core.TrainContext = None,
          default_plots: bool = None):
    """Evaluate the environment with a uniform random policy.

    Evaluation happens in batches of ``num_episodes_per_eval`` episodes. When no
    ``train_context`` is passed, a fresh ``core.TrainContext`` is created and
    populated from the scalar arguments; when one is passed it is forwarded
    unchanged and the scalar arguments are ignored.

    Args:
        callbacks: a single callback or a list of callbacks invoked during
            training and evaluation.
        num_iterations: how many batches of ``num_episodes_per_eval`` episodes
            are evaluated.
        max_steps_per_episode: upper bound on the number of steps per episode.
        num_episodes_per_eval: number of episodes played to estimate the
            average return and steps.
        train_context: training configuration to use. If given, it takes
            precedence over all other training-context arguments.
        default_plots: if set, a set of default callbacks (plot.State,
            plot.Rewards, plot.Loss, ...) is added.

    Returns:
        The training configuration that was used, containing the loss and sum
        of rewards encountered during training.
    """
    if train_context is None:
        # No context supplied by the caller: derive one from the arguments.
        context = core.TrainContext()
        context.num_iterations = num_iterations
        context.max_steps_per_episode = max_steps_per_episode
        # Zero epochs per iteration, with an evaluation after every iteration.
        context.num_epochs_per_iteration = 0
        context.num_iterations_between_eval = 1
        context.num_episodes_per_eval = num_episodes_per_eval
        # NOTE(review): presumably a placeholder value, since no epochs are
        # scheduled -- confirm against core.TrainContext.
        context.learning_rate = 1
        train_context = context

    super().train(train_context=train_context,
                  callbacks=callbacks,
                  default_plots=default_plots)
    return train_context
def test_train(self):
    """Train a TfRandomAgent on the `_lineworld_name` model and check the episode count.

    The run uses the `duration.Fast` and `log.Iteration` callbacks; afterwards
    the train context must report exactly one episode done in the iteration.
    """
    config = core.ModelConfig(_lineworld_name)
    context = core.TrainContext()
    agent = tfagents.TfRandomAgent(model_config=config)

    training_callbacks = [duration.Fast(), log.Iteration()]
    agent.train(train_context=context, callbacks=training_callbacks)

    assert context.episodes_done_in_iteration == 1
def test_random_train(self):
    """Smoke test: train a TforceRandomAgent on "CartPole-v0".

    There is no explicit assertion; the test passes when `train` returns
    without raising.
    """
    config = core.ModelConfig("CartPole-v0")
    context = core.TrainContext()
    agent = tforce.TforceRandomAgent(model_config=config)

    training_callbacks = [log.Iteration(), log.Agent(), duration.Fast()]
    agent.train(train_context=context, callbacks=training_callbacks)
def test_random_train(self):
    """Train a TforceRandomAgent for 50 iterations and bound its average reward.

    After training, the entry of `eval_rewards` stored under
    `episodes_done_in_training` is unpacked as a (min, avg, max) triple and the
    average is required to stay below 50.
    """
    # Imported locally, as in the original test, so the backend is only
    # loaded when this test actually runs.
    from easyagents.backends import tforce

    config = core.ModelConfig(_cartpole_name)
    context = core.TrainContext()
    context.num_iterations = 50

    agent = tforce.TforceRandomAgent(model_config=config)
    agent.train(train_context=context, callbacks=[log.Iteration(), log.Agent()])

    # Only the average of the (min, avg, max) triple is checked.
    _, avg_r, _ = context.eval_rewards[context.episodes_done_in_training]
    assert avg_r < 50