def test_trainer_sarsa_enum(self):
    """SARSA training on the enum gridworld.

    The evaluator's Monte-Carlo loss must start above 0.15 for the untrained
    predictor, drop below 0.05 after training, and the doubly-robust reward
    estimate must improve over its pre-training value.
    """
    env = GridworldEnum()
    episodes = env.generate_samples(100000, 1.0)
    evaluator = GridworldEvaluator(env, False)
    trainer = self.get_sarsa_trainer(env)
    predictor = trainer.predictor()
    minibatches = env.preprocess_samples(episodes, self.minibatch_size)

    # Baseline: evaluate before any training has happened.
    evaluator.evaluate(predictor)
    pre_loss = evaluator.mc_loss[-1]
    pre_dr = evaluator.reward_doubly_robust[-1]
    print(
        "Pre-Training eval: ",
        pre_loss,
        pre_dr,
    )
    self.assertGreater(pre_loss, 0.15)

    for minibatch in minibatches:
        trainer.train_numpy(minibatch, None)

    # Re-evaluate the (shared) predictor after all minibatches are trained.
    evaluator.evaluate(predictor)
    post_loss = evaluator.mc_loss[-1]
    post_dr = evaluator.reward_doubly_robust[-1]
    print(
        "Post-Training eval: ",
        post_loss,
        post_dr,
    )
    self.assertLess(post_loss, 0.05)
    # [-2] is the pre-training doubly-robust estimate appended above.
    self.assertGreater(post_dr, evaluator.reward_doubly_robust[-2])
def test_trainer_sarsa_enum(self):
    """SARSA training on the enum gridworld (legacy unpacked-sample API).

    The evaluator score (an error metric returned by ``evaluate``) must be
    above 0.15 for the untrained predictor and below 0.05 after streaming
    every minibatch through the trainer.
    """
    environment = GridworldEnum()
    states, actions, rewards, next_states, next_actions, is_terminal,\
        possible_next_actions, reward_timelines = \
        environment.generate_samples(100000, 1.0)
    evaluator = GridworldEvaluator(environment, False)
    trainer = self.get_sarsa_trainer(environment)
    predictor = trainer.predictor()
    tdps = environment.preprocess_samples(
        states,
        actions,
        rewards,
        next_states,
        next_actions,
        is_terminal,
        possible_next_actions,
        reward_timelines,
        self.minibatch_size,
    )

    # Untrained predictor should score poorly.
    self.assertGreater(evaluator.evaluate(predictor), 0.15)

    for tdp in tdps:
        trainer.stream_tdp(tdp, None)

    # FIX: the original called evaluator.evaluate(predictor) an extra time
    # here and discarded the result before asserting on a second call.
    # A single evaluation is sufficient; nothing reads evaluator state in
    # this test, so dropping the redundant (and expensive) pass is safe.
    self.assertLess(evaluator.evaluate(predictor), 0.05)
def test_trainer_sarsa_enum(self):
    """SARSA training on the enum gridworld, 500k-sample variant.

    Asserts the Monte-Carlo loss starts above 0.12 and falls below 0.1
    after training on every preprocessed minibatch.
    """
    env = GridworldEnum()
    episodes = env.generate_samples(500000, 1.0)
    evaluator = GridworldEvaluator(env, False, DISCOUNT, False, episodes)
    trainer = self.get_sarsa_trainer(env)
    predictor = trainer.predictor()
    minibatches = env.preprocess_samples(episodes, self.minibatch_size)

    # Baseline evaluation with the untrained predictor.
    evaluator.evaluate(predictor)
    pre_loss = evaluator.mc_loss[-1]
    print(
        "Pre-Training eval: ",
        pre_loss,
        evaluator.value_doubly_robust[-1],
    )
    self.assertGreater(pre_loss, 0.12)

    for minibatch in minibatches:
        # trainer.train expects 1-D reward / not_terminal arrays here,
        # so flatten the preprocessed columns in place before training.
        minibatch.rewards = minibatch.rewards.flatten()
        minibatch.not_terminals = minibatch.not_terminals.flatten()
        trainer.train(minibatch)

    # Fetch a fresh predictor reflecting the trained weights.
    predictor = trainer.predictor()
    evaluator.evaluate(predictor)
    post_loss = evaluator.mc_loss[-1]
    print(
        "Post-Training eval: ",
        post_loss,
        evaluator.value_doubly_robust[-1],
    )
    self.assertLess(post_loss, 0.1)
def _test_trainer_sarsa_enum(self, use_gpu=False, use_all_avail_gpus=False):
    """Shared driver for the enum-gridworld SARSA test, CPU or GPU.

    Asserts the Monte-Carlo loss starts above 0.12 untrained and falls
    below 0.1 after training on all minibatches.
    """
    env = GridworldEnum()
    episodes = env.generate_samples(100000, 1.0, DISCOUNT)
    evaluator = GridworldEvaluator(env, False, DISCOUNT, False, episodes)
    trainer = self.get_sarsa_trainer(
        env, False, use_gpu=use_gpu, use_all_avail_gpus=use_all_avail_gpus
    )
    predictor = trainer.predictor()
    minibatches = env.preprocess_samples(
        episodes, self.minibatch_size, use_gpu=use_gpu
    )

    # Baseline evaluation with the untrained predictor.
    evaluator.evaluate(predictor)
    pre_loss = evaluator.mc_loss[-1]
    print(
        "Pre-Training eval: ",
        pre_loss,
        evaluator.value_doubly_robust[-1],
    )
    self.assertGreater(pre_loss, 0.12)

    for minibatch in minibatches:
        trainer.train(minibatch)

    # Fetch a fresh predictor reflecting the trained weights.
    predictor = trainer.predictor()
    evaluator.evaluate(predictor)
    post_loss = evaluator.mc_loss[-1]
    print(
        "Post-Training eval: ",
        post_loss,
        evaluator.value_doubly_robust[-1],
    )
    self.assertLess(post_loss, 0.1)
def envs_and_evaluators():
    """Pair each gridworld environment with its matching evaluator class."""
    pairings = [
        (Gridworld(), GridworldEvaluator),
        (GridworldEnum(), GridworldEnumEvaluator),
    ]
    return pairings
def envs():
    """Return each gridworld environment wrapped in a one-tuple."""
    return [(env,) for env in (Gridworld(), GridworldEnum())]