def test_trainer_sarsa_enum(self):
        """Check the SARSA trainer on ``GridworldContinuousEnum`` minibatches.

        Samples are generated from the environment, preprocessed into
        minibatches of ``self.minibatch_size``, and fed one by one to
        ``train_numpy``. The value returned by the evaluator must exceed 0.15
        before training and fall below 0.05 afterwards.
        """
        env = GridworldContinuousEnum()
        (
            states,
            actions,
            rewards,
            next_states,
            next_actions,
            is_terminal,
            possible_next_actions,
            reward_timelines,
        ) = env.generate_samples(100000, 1.0)
        sarsa_trainer = self.get_sarsa_trainer(env)
        sarsa_predictor = sarsa_trainer.predictor()
        gridworld_evaluator = GridworldContinuousEvaluator(env, False)
        minibatches = env.preprocess_samples(
            states, actions, rewards, next_states, next_actions, is_terminal,
            possible_next_actions, reward_timelines, self.minibatch_size,
        )

        # Untrained predictor: the evaluator's result is expected to be high.
        score_before = gridworld_evaluator.evaluate(sarsa_predictor)
        self.assertGreater(score_before, 0.15)

        for minibatch in minibatches:
            sarsa_trainer.train_numpy(minibatch, None)
        # Result intentionally discarded; the call itself is preserved.
        gridworld_evaluator.evaluate(sarsa_predictor)

        score_after = gridworld_evaluator.evaluate(sarsa_predictor)
        self.assertLess(score_after, 0.05)
    def test_trainer_single_batch_maxq(self, environment):
        """Check max-Q training when all samples are streamed as one batch.

        A ``ContinuousActionDQNTrainer`` is built with ``maxq_learning=True``,
        reusing the ``training`` and ``knn`` sections of the SARSA parameters.
        The preprocessed samples are streamed through the trainer twice; the
        evaluator's result must exceed 0.4 before training and be below 0.1
        afterwards.

        Args:
            environment: object providing ``normalization``,
                ``normalization_action``, ``generate_samples`` and
                ``preprocess_samples``.
        """
        sarsa_params = self.get_sarsa_parameters()
        maxq_rl = RLParameters(
            gamma=DISCOUNT,
            target_update_rate=0.5,
            reward_burnin=10,
            maxq_learning=True,
        )
        maxq_params = ContinuousActionModelParameters(
            rl=maxq_rl,
            training=sarsa_params.training,
            knn=sarsa_params.knn,
        )
        maxq_trainer = ContinuousActionDQNTrainer(
            environment.normalization,
            environment.normalization_action,
            maxq_params,
        )

        (
            states,
            actions,
            rewards,
            next_states,
            next_actions,
            is_terminal,
            possible_next_actions,
            reward_timelines,
        ) = environment.generate_samples(100000, 1.0)
        maxq_predictor = maxq_trainer.predictor()
        whole_batch = environment.preprocess_samples(
            states, actions, rewards, next_states, next_actions, is_terminal,
            possible_next_actions, reward_timelines,
        )
        gridworld_evaluator = GridworldContinuousEvaluator(environment, True)

        score_before = gridworld_evaluator.evaluate(maxq_predictor)
        self.assertGreater(score_before, 0.4)

        passes = 2
        for _pass in range(passes):
            maxq_trainer.stream_tdp(whole_batch)
            # Evaluate after every pass; the result is not used here.
            gridworld_evaluator.evaluate(maxq_predictor)

        score_after = gridworld_evaluator.evaluate(maxq_predictor)
        self.assertLess(score_after, 0.1)
    def test_trainer_maxq(self):
        """Check max-Q training on ``GridworldContinuous`` minibatches.

        A ``ContinuousActionDQNTrainer`` is built with ``maxq_learning=True``,
        reusing the ``training`` and ``knn`` sections of the SARSA parameters.
        The minibatches are trained on for two passes; the evaluator's result
        must exceed 0.2 before training and be below 0.15 afterwards.
        """
        env = GridworldContinuous()
        sarsa_params = self.get_sarsa_parameters()
        maxq_rl = RLParameters(
            gamma=DISCOUNT,
            target_update_rate=0.5,
            reward_burnin=10,
            maxq_learning=True,
        )
        maxq_params = ContinuousActionModelParameters(
            rl=maxq_rl,
            training=sarsa_params.training,
            knn=sarsa_params.knn,
        )
        maxq_trainer = ContinuousActionDQNTrainer(
            maxq_params, env.normalization, env.normalization_action
        )

        generated = env.generate_samples(100000, 1.0)
        maxq_predictor = maxq_trainer.predictor()
        minibatches = env.preprocess_samples(generated, self.minibatch_size)
        gridworld_evaluator = GridworldContinuousEvaluator(env, True)

        score_before = gridworld_evaluator.evaluate(maxq_predictor)
        self.assertGreater(score_before, 0.2)

        passes = 2
        for _pass in range(passes):
            for minibatch in minibatches:
                maxq_trainer.train_numpy(minibatch, None)
            # Evaluate once per pass; the result is not used here.
            gridworld_evaluator.evaluate(maxq_predictor)

        score_after = gridworld_evaluator.evaluate(maxq_predictor)
        self.assertLess(score_after, 0.15)
# Example #4
# 0
    def test_trainer_sarsa_enum(self):
        """Check the SARSA trainer on ``GridworldContinuousEnum`` minibatches.

        The generated samples are passed both to the evaluator and to
        ``preprocess_samples``. After one ``train_numpy`` call per minibatch,
        the evaluator's result must be below 0.15.
        """
        env = GridworldContinuousEnum()
        generated = env.generate_samples(100000, 1.0)
        sarsa_trainer = self.get_sarsa_trainer(env)
        sarsa_predictor = sarsa_trainer.predictor()
        gridworld_evaluator = GridworldContinuousEvaluator(
            env, False, DISCOUNT, False, generated
        )
        minibatches = env.preprocess_samples(generated, self.minibatch_size)

        for minibatch in minibatches:
            sarsa_trainer.train_numpy(minibatch, None)
        # Result intentionally discarded; the call itself is preserved.
        gridworld_evaluator.evaluate(sarsa_predictor)

        score_after = gridworld_evaluator.evaluate(sarsa_predictor)
        self.assertLess(score_after, 0.15)
# Example #5
# 0
    def test_trainer_sarsa_enum_factorized(self):
        """Check the SARSA trainer built from the factorized parameters.

        The trainer is created with ``get_sarsa_parameters_factorized()`` and
        trained with one ``train`` call per minibatch. The predictor is fetched
        again from the trainer after training, and the evaluator's result for
        it must be below 0.15.
        """
        env = GridworldContinuousEnum()
        generated = env.generate_samples(500000, 1.0, DISCOUNT)
        factorized_params = self.get_sarsa_parameters_factorized()
        sarsa_trainer = self.get_sarsa_trainer(env, factorized_params)
        # Fetched before training as in the original flow; replaced below.
        sarsa_predictor = sarsa_trainer.predictor()
        gridworld_evaluator = GridworldContinuousEvaluator(
            env, False, DISCOUNT, False, generated
        )
        minibatches = env.preprocess_samples(generated, self.minibatch_size)

        for minibatch in minibatches:
            sarsa_trainer.train(minibatch)

        sarsa_predictor = sarsa_trainer.predictor()
        # Result intentionally discarded; the call itself is preserved.
        gridworld_evaluator.evaluate(sarsa_predictor)

        score_after = gridworld_evaluator.evaluate(sarsa_predictor)
        self.assertLess(score_after, 0.15)
# Example #6
# 0
    def test_trainer_sarsa_enum(self):
        """Check the SARSA trainer on ``GridworldContinuousEnum`` minibatches.

        Before each ``train`` call the minibatch's ``rewards`` and
        ``not_terminals`` are replaced by their ``flatten()`` results. The
        predictor is fetched again after training, and the evaluator's result
        for it must be below 0.15.
        """
        env = GridworldContinuousEnum()
        generated = env.generate_samples(150000, 1.0)
        sarsa_trainer = self.get_sarsa_trainer(env)
        # Fetched before training as in the original flow; replaced below.
        sarsa_predictor = sarsa_trainer.predictor()
        gridworld_evaluator = GridworldContinuousEvaluator(
            env, False, DISCOUNT, False, generated
        )
        minibatches = env.preprocess_samples(generated, self.minibatch_size)

        for minibatch in minibatches:
            flat_rewards = minibatch.rewards.flatten()
            minibatch.rewards = flat_rewards
            flat_not_terminals = minibatch.not_terminals.flatten()
            minibatch.not_terminals = flat_not_terminals
            sarsa_trainer.train(minibatch)

        sarsa_predictor = sarsa_trainer.predictor()
        # Result intentionally discarded; the call itself is preserved.
        gridworld_evaluator.evaluate(sarsa_predictor)

        score_after = gridworld_evaluator.evaluate(sarsa_predictor)
        self.assertLess(score_after, 0.15)
    def test_trainer_single_batch_sarsa(self, environment):
        """Check SARSA training when all samples are streamed as one batch.

        The preprocessed samples are streamed through the trainer once with
        ``stream_tdp``. The evaluator's result must exceed 0.15 before training
        and be below 0.05 afterwards.

        Args:
            environment: object providing ``generate_samples`` and
                ``preprocess_samples``; also passed to the trainer factory and
                the evaluator.
        """
        (
            states,
            actions,
            rewards,
            next_states,
            next_actions,
            is_terminal,
            possible_next_actions,
            reward_timelines,
        ) = environment.generate_samples(100000, 1.0)
        sarsa_trainer = self.get_sarsa_trainer(environment)
        sarsa_predictor = sarsa_trainer.predictor()
        gridworld_evaluator = GridworldContinuousEvaluator(environment, False)
        whole_batch = environment.preprocess_samples(
            states, actions, rewards, next_states, next_actions, is_terminal,
            possible_next_actions, reward_timelines,
        )

        score_before = gridworld_evaluator.evaluate(sarsa_predictor)
        self.assertGreater(score_before, 0.15)

        sarsa_trainer.stream_tdp(whole_batch)
        # Result intentionally discarded; the call itself is preserved.
        gridworld_evaluator.evaluate(sarsa_predictor)

        score_after = gridworld_evaluator.evaluate(sarsa_predictor)
        self.assertLess(score_after, 0.05)