Пример #1
0
 def true_values_for_sample(self, enum_states, actions,
                            assume_optimal_policy: bool):
     """Re-encode enumerated states and delegate to Gridworld.

     Each entry of ``enum_states`` is a mapping; its first value is cast to
     ``int`` and becomes the sole key (mapped to ``1``) of the state dict
     handed to ``Gridworld.true_values_for_sample``. ``actions`` and
     ``assume_optimal_policy`` are forwarded unchanged.
     """
     converted_states = [
         {int(list(enum_state.values())[0]): 1} for enum_state in enum_states
     ]
     return Gridworld.true_values_for_sample(
         self, converted_states, actions, assume_optimal_policy
     )
Пример #2
0
    def test_evaluator_ground_truth(self):
        """Stream ground-truth-labelled samples and check the final losses.

        Generates Gridworld samples, replaces every reward timeline with the
        environment's true value for that sample, streams the preprocessed
        batches through the trainer returned by ``get_sarsa_trainer``, and
        asserts the last TD and MC losses recorded by the evaluator are both
        below 0.05.
        """
        env = Gridworld()
        (
            states,
            actions,
            rewards,
            next_states,
            next_actions,
            is_terminal,
            possible_next_actions,
            _,
        ) = env.generate_samples(100000, 1.0)
        ground_truth = env.true_values_for_sample(states, actions, False)
        # Hijack the reward timeline to insert the ground truth
        reward_timelines = [{0: value} for value in ground_truth]
        sarsa_trainer = self.get_sarsa_trainer(env)
        evaluator = Evaluator(sarsa_trainer, DISCOUNT)
        batches = env.preprocess_samples(
            states,
            actions,
            rewards,
            next_states,
            next_actions,
            is_terminal,
            possible_next_actions,
            reward_timelines,
            self.minibatch_size,
        )

        for batch in batches:
            sarsa_trainer.stream_tdp(batch, evaluator)

        self.assertLess(evaluator.td_loss[-1], 0.05)
        self.assertLess(evaluator.mc_loss[-1], 0.05)
Пример #3
0
    def test_evaluator_ground_truth_no_dueling(self):
        """Train on ground-truth episode values and check the final MC loss.

        Generates Gridworld samples, overwrites ``episode_values`` with the
        environment's true values, trains the trainer obtained from
        ``get_sarsa_trainer(environment, False)`` on every preprocessed
        batch, and asserts the evaluator's last MC loss is below 0.1.
        """
        env = Gridworld()
        batch_source = env.generate_samples(500000, 1.0, DISCOUNT)
        ground_truth = env.true_values_for_sample(
            batch_source.states, batch_source.actions, False
        )
        # Hijack the episode values to insert the ground truth
        batch_source.episode_values = ground_truth
        sarsa_trainer = self.get_sarsa_trainer(env, False)
        evaluator = Evaluator(env.ACTIONS, 10, DISCOUNT, None, None)
        batches = env.preprocess_samples(batch_source, self.minibatch_size)

        for batch in batches:
            sarsa_trainer.train(batch, evaluator)

        self.assertLess(evaluator.mc_loss[-1], 0.1)
Пример #4
0
    def test_evaluator_ground_truth(self):
        """Train twice over ground-truth-labelled samples and check MC loss.

        Generates Gridworld samples, replaces ``reward_timelines`` with one
        ``{0: true_value}`` entry per sample, runs two passes of
        ``train_numpy`` over the preprocessed batches, and asserts the
        evaluator's last MC loss is below 0.1.
        """
        env = Gridworld()
        batch_source = env.generate_samples(200000, 1.0)
        ground_truth = env.true_values_for_sample(
            batch_source.states, batch_source.actions, False
        )
        # Hijack the reward timeline to insert the ground truth
        batch_source.reward_timelines = []
        batch_source.reward_timelines.extend(
            {0: value} for value in ground_truth
        )
        sarsa_trainer = self.get_sarsa_trainer(env)
        evaluator = Evaluator(env.ACTIONS, 10, DISCOUNT, None, None)
        batches = env.preprocess_samples(batch_source, self.minibatch_size)

        # Two passes over the same batch collection.
        for _pass in range(2):
            for batch in batches:
                sarsa_trainer.train_numpy(batch, evaluator)

        self.assertLess(evaluator.mc_loss[-1], 0.1)