def test_trainer_sarsa_enum(self):
        """Check that SARSA training lowers the evaluator's score on the enum gridworld.

        The value returned by ``evaluator.evaluate`` must exceed 0.15 before
        any training and fall below 0.05 once every preprocessed batch has
        been fed to the trainer.
        """
        env = GridworldContinuousEnum()
        # NOTE(review): the arguments are presumably (sample count, epsilon)
        # -- confirm against generate_samples.
        (
            sample_states,
            sample_actions,
            sample_rewards,
            sample_next_states,
            sample_next_actions,
            sample_is_terminal,
            sample_possible_next_actions,
            sample_reward_timelines,
        ) = env.generate_samples(100000, 1.0)
        sarsa_trainer = self.get_sarsa_trainer(env)
        sarsa_predictor = sarsa_trainer.predictor()
        grid_evaluator = GridworldContinuousEvaluator(env, False)
        batches = env.preprocess_samples(
            sample_states,
            sample_actions,
            sample_rewards,
            sample_next_states,
            sample_next_actions,
            sample_is_terminal,
            sample_possible_next_actions,
            sample_reward_timelines,
            self.minibatch_size,
        )

        # Baseline measurement taken before any training step.
        untrained_score = grid_evaluator.evaluate(sarsa_predictor)
        self.assertGreater(untrained_score, 0.15)

        for batch in batches:
            sarsa_trainer.train_numpy(batch, None)
        # This evaluation's result is intentionally unused; the call itself
        # is kept so the evaluator is invoked the same number of times.
        grid_evaluator.evaluate(sarsa_predictor)

        trained_score = grid_evaluator.evaluate(sarsa_predictor)
        self.assertLess(trained_score, 0.05)
Exemple #2
0
    def test_trainer_sarsa_enum(self):
        """Train SARSA on the enum gridworld and bound the evaluator's score.

        After every preprocessed batch has been passed to
        ``trainer.train_numpy``, the value returned by
        ``evaluator.evaluate`` must be below 0.15.
        """
        env = GridworldContinuousEnum()
        # NOTE(review): the arguments are presumably (sample count, epsilon)
        # -- confirm against generate_samples.
        generated = env.generate_samples(100000, 1.0)
        sarsa_trainer = self.get_sarsa_trainer(env)
        sarsa_predictor = sarsa_trainer.predictor()
        grid_evaluator = GridworldContinuousEvaluator(
            env, False, DISCOUNT, False, generated
        )
        batches = env.preprocess_samples(generated, self.minibatch_size)

        for batch in batches:
            sarsa_trainer.train_numpy(batch, None)
        # This evaluation's result is intentionally unused; the call itself
        # is kept so the evaluator is invoked the same number of times.
        grid_evaluator.evaluate(sarsa_predictor)

        trained_score = grid_evaluator.evaluate(sarsa_predictor)
        self.assertLess(trained_score, 0.15)
Exemple #3
0
    def test_trainer_sarsa_enum_factorized(self):
        """Train SARSA with the factorized parameters on the enum gridworld.

        The trainer is built from ``get_sarsa_parameters_factorized()``.
        After training on every preprocessed batch, a fresh predictor is
        taken from the trainer and the value returned by
        ``evaluator.evaluate`` must be below 0.15.
        """
        env = GridworldContinuousEnum()
        # NOTE(review): the arguments are presumably (sample count, epsilon,
        # discount) -- confirm against generate_samples.
        generated = env.generate_samples(500000, 1.0, DISCOUNT)
        factorized_params = self.get_sarsa_parameters_factorized()
        sarsa_trainer = self.get_sarsa_trainer(env, factorized_params)
        # The predictor is fetched before training as well, mirroring the
        # call sequence this test has always used.
        sarsa_predictor = sarsa_trainer.predictor()
        grid_evaluator = GridworldContinuousEvaluator(
            env, False, DISCOUNT, False, generated
        )
        batches = env.preprocess_samples(generated, self.minibatch_size)

        for batch in batches:
            sarsa_trainer.train(batch)

        # Re-fetch the predictor once training is complete.
        sarsa_predictor = sarsa_trainer.predictor()
        # This evaluation's result is intentionally unused; the call itself
        # is kept so the evaluator is invoked the same number of times.
        grid_evaluator.evaluate(sarsa_predictor)

        trained_score = grid_evaluator.evaluate(sarsa_predictor)
        self.assertLess(trained_score, 0.15)
Exemple #4
0
    def test_trainer_sarsa_enum(self):
        """Train SARSA on the enum gridworld using flattened batch fields.

        Each preprocessed batch has its ``rewards`` and ``not_terminals``
        flattened before it is handed to ``trainer.train``. After training,
        a fresh predictor is taken from the trainer and the value returned
        by ``evaluator.evaluate`` must be below 0.15.
        """
        env = GridworldContinuousEnum()
        # NOTE(review): the arguments are presumably (sample count, epsilon)
        # -- confirm against generate_samples.
        generated = env.generate_samples(150000, 1.0)
        sarsa_trainer = self.get_sarsa_trainer(env)
        # The predictor is fetched before training as well, mirroring the
        # call sequence this test has always used.
        sarsa_predictor = sarsa_trainer.predictor()
        grid_evaluator = GridworldContinuousEvaluator(
            env, False, DISCOUNT, False, generated
        )
        batches = env.preprocess_samples(generated, self.minibatch_size)

        for batch in batches:
            # Replace both fields with their flattened versions before training.
            flat_rewards = batch.rewards.flatten()
            batch.rewards = flat_rewards
            flat_not_terminals = batch.not_terminals.flatten()
            batch.not_terminals = flat_not_terminals
            sarsa_trainer.train(batch)

        # Re-fetch the predictor once training is complete.
        sarsa_predictor = sarsa_trainer.predictor()
        # This evaluation's result is intentionally unused; the call itself
        # is kept so the evaluator is invoked the same number of times.
        grid_evaluator.evaluate(sarsa_predictor)

        trained_score = grid_evaluator.evaluate(sarsa_predictor)
        self.assertLess(trained_score, 0.15)
 def envs():
     """Return a list of one-element tuples, each holding a fresh environment.

     The list contains a ``GridworldContinuous`` instance followed by a
     ``GridworldContinuousEnum`` instance.
     """
     continuous_env = GridworldContinuous()
     enum_env = GridworldContinuousEnum()
     return [(continuous_env,), (enum_env,)]
 def envs_and_evaluators():
     """Return (environment instance, evaluator class) pairs.

     Each tuple holds a freshly constructed environment and the evaluator
     class (not an instance) listed alongside it.
     """
     pairs = []
     pairs.append((GridworldContinuous(), GridworldContinuousEvaluator))
     pairs.append(
         (GridworldContinuousEnum(), GridworldContinuousEnumEvaluator)
     )
     return pairs