Пример #1
0
 def _test_evaluator_ground_truth(
     self,
     dueling=False,
     use_gpu=False,
     use_all_avail_gpus=False,
     clip_grad_norm=None,
     modular=False,
 ):
     """Build a SARSA trainer/exporter for Gridworld and evaluate it.

     A fresh ``Gridworld`` and a ``GridworldEvaluator`` (constructed with
     ``False`` and ``DISCOUNT``) are created, a trainer/exporter pair is
     obtained from one of the two factory methods on ``self``, and
     everything is handed to ``self.evaluate_gridworld``.

     Args:
         dueling: Forwarded to the trainer/exporter factory.
         use_gpu: Forwarded to the factory and to ``evaluate_gridworld``.
         use_all_avail_gpus: Forwarded to the factory.
         clip_grad_norm: Forwarded to the factory.
         modular: When true, ``get_modular_sarsa_trainer_exporter`` is
             used; otherwise ``get_sarsa_trainer_exporter``.
     """
     environment = Gridworld()
     evaluator = GridworldEvaluator(environment, False, DISCOUNT)

     # Both factories take the same positional arguments, so select the
     # factory first and call it in a single place.
     if modular:
         build_trainer_exporter = self.get_modular_sarsa_trainer_exporter
     else:
         build_trainer_exporter = self.get_sarsa_trainer_exporter

     trainer, exporter = build_trainer_exporter(
         environment,
         {},
         dueling,
         use_gpu,
         use_all_avail_gpus,
         clip_grad_norm,
     )
     self.evaluate_gridworld(
         environment, evaluator, trainer, exporter, use_gpu
     )
Пример #2
0
 def _test_reward_boost(
     self, use_gpu=False, use_all_avail_gpus=False, modular=False
 ):
     """Evaluate a SARSA trainer built with per-action reward boosts.

     The boost mapping ``{"L": 100, "R": 200, "U": 300, "D": 400}`` is
     passed to the selected trainer/exporter factory; the resulting pair
     is then checked through ``self.evaluate_gridworld``.

     Args:
         use_gpu: Forwarded to the factory and to ``evaluate_gridworld``.
         use_all_avail_gpus: Forwarded to the factory.
         modular: When true, ``get_modular_sarsa_trainer_exporter`` is
             used; otherwise ``get_sarsa_trainer_exporter``.
     """
     environment = Gridworld()
     reward_boost = {"L": 100, "R": 200, "U": 300, "D": 400}

     make_trainer_exporter = (
         self.get_modular_sarsa_trainer_exporter
         if modular
         else self.get_sarsa_trainer_exporter
     )
     trainer, exporter = make_trainer_exporter(
         environment, reward_boost, False, use_gpu, use_all_avail_gpus
     )

     # The evaluator is built after the trainer, as in the original flow.
     evaluator_options = {
         "env": environment,
         "assume_optimal_policy": False,
         "gamma": DISCOUNT,
         "use_int_features": False,
     }
     evaluator = GridworldEvaluator(**evaluator_options)
     self.evaluate_gridworld(
         environment, evaluator, trainer, exporter, use_gpu
     )
Пример #3
0
 def generate_samples(self, num_transitions, epsilon, with_possible=True) -> Samples:
     """Generate Gridworld samples and re-encode their states.

     Sampling is delegated to ``Gridworld.generate_samples``. Every state
     and next state in the result is replaced by a one-entry dict
     ``{0: float(k)}``, where ``k`` is the first key of the original
     state dict. All other sample fields are passed through unchanged.

     Args:
         num_transitions: Forwarded to ``Gridworld.generate_samples``.
         epsilon: Forwarded to ``Gridworld.generate_samples``.
         with_possible: Forwarded to ``Gridworld.generate_samples``.

     Returns:
         A new ``Samples`` carrying the re-encoded ``states`` and
         ``next_states``.
     """

     def _reencode(state):
         # Feature 0 carries the first key of the original state as a float.
         first_key = list(state.keys())[0]
         return {0: float(first_key)}

     base = Gridworld.generate_samples(
         self, num_transitions, epsilon, with_possible
     )
     converted_states = [_reencode(state) for state in base.states]
     converted_next_states = [_reencode(state) for state in base.next_states]

     return Samples(
         states=converted_states,
         actions=base.actions,
         propensities=base.propensities,
         rewards=base.rewards,
         next_states=converted_next_states,
         next_actions=base.next_actions,
         is_terminal=base.is_terminal,
         possible_next_actions=base.possible_next_actions,
         reward_timelines=base.reward_timelines,
     )
Пример #4
0
 def _test_evaluator_ground_truth(
     self,
     dueling=False,
     categorical=False,
     quantile=False,
     use_gpu=False,
     use_all_avail_gpus=False,
     clip_grad_norm=None,
 ):
     """Build a trainer for Gridworld and evaluate it.

     A fresh ``Gridworld`` and a ``GridworldEvaluator`` (constructed with
     ``False`` and ``DISCOUNT``) are created, a trainer is obtained from
     ``self.get_trainer`` with an empty dict as second argument, and all
     three are handed to ``self.evaluate_gridworld``.

     Args:
         dueling: Forwarded to ``get_trainer`` by keyword.
         categorical: Forwarded to ``get_trainer`` by keyword.
         quantile: Forwarded to ``get_trainer`` by keyword.
         use_gpu: Forwarded to ``get_trainer`` and ``evaluate_gridworld``.
         use_all_avail_gpus: Forwarded to ``get_trainer`` by keyword.
         clip_grad_norm: Forwarded to ``get_trainer`` by keyword.
     """
     environment = Gridworld()
     evaluator = GridworldEvaluator(environment, False, DISCOUNT)

     # Every test knob maps one-to-one onto a get_trainer keyword.
     trainer_options = {
         "dueling": dueling,
         "categorical": categorical,
         "quantile": quantile,
         "use_gpu": use_gpu,
         "use_all_avail_gpus": use_all_avail_gpus,
         "clip_grad_norm": clip_grad_norm,
     }
     trainer = self.get_trainer(environment, {}, **trainer_options)

     self.evaluate_gridworld(environment, evaluator, trainer, use_gpu)
Пример #5
0
 def generate_samples(self, num_transitions, epsilon,
                      discount_factor) -> Samples:
     """Generate Gridworld samples and re-encode their states.

     Sampling is delegated to ``Gridworld.generate_samples``. Every state
     and next state in the result is replaced by a one-entry dict
     ``{0: float(k)}``, where ``k`` is the first key of the original
     state dict. All other sample fields are passed through unchanged.

     Args:
         num_transitions: Forwarded to ``Gridworld.generate_samples``.
         epsilon: Forwarded to ``Gridworld.generate_samples``.
         discount_factor: Forwarded to ``Gridworld.generate_samples``.

     Returns:
         A new ``Samples`` carrying the re-encoded ``states`` and
         ``next_states``.
     """

     def _to_enum_state(state):
         # Feature 0 carries the first key of the original state as a float.
         return {0: float(list(state.keys())[0])}

     raw = Gridworld.generate_samples(
         self, num_transitions, epsilon, discount_factor
     )
     states_as_enum = list(map(_to_enum_state, raw.states))
     next_states_as_enum = list(map(_to_enum_state, raw.next_states))

     return Samples(
         mdp_ids=raw.mdp_ids,
         sequence_numbers=raw.sequence_numbers,
         states=states_as_enum,
         actions=raw.actions,
         propensities=raw.propensities,
         rewards=raw.rewards,
         possible_actions=raw.possible_actions,
         next_states=next_states_as_enum,
         next_actions=raw.next_actions,
         terminals=raw.terminals,
         possible_next_actions=raw.possible_next_actions,
         episode_values=raw.episode_values,
     )
Пример #6
0
    def test_predictor_torch_export(self):
        """Verify that q-values before model export equal q-values after
        model export. Meant to catch issues with export logic.

        A single hand-written transition is preprocessed, q-values are read
        from the trainer's q-network, the network is wrapped for serving,
        exported to a temporary file, loaded back with ``torch.jit.load``
        and queried again. The two sets of q-values must agree to four
        decimal places for every action in ``environment.ACTIONS``.
        """
        environment = Gridworld()
        samples = Samples(
            mdp_ids=["0"],
            sequence_numbers=[0],
            sequence_number_ordinals=[1],
            states=[{
                0: 1.0,
                1: 1.0,
                2: 1.0,
                3: 1.0,
                4: 1.0,
                5: 1.0,
                15: 1.0,
                24: 1.0,
            }],
            actions=["D"],
            action_probabilities=[0.5],
            rewards=[0],
            possible_actions=[["R", "D"]],
            next_states=[{5: 1.0}],
            next_actions=["U"],
            terminals=[False],
            possible_next_actions=[["R", "U", "D"]],
        )
        tdps = environment.preprocess_samples(samples, 1)
        # Use the TestCase assertion so the check is not stripped under -O
        # and is reported like the other assertions in this test.
        self.assertEqual(len(tdps), 1, "Invalid number of data pages")

        trainer = self.get_trainer(environment, {}, False, False, False)
        # Named ``state_input`` to avoid shadowing the ``input`` builtin.
        state_input = rlt.PreprocessedState.from_tensor(tdps[0].states)

        pre_export_q_values = (
            trainer.q_network(state_input).q_values.detach().numpy()
        )

        preprocessor = Preprocessor(environment.normalization, False)
        cpu_q_network = trainer.q_network.cpu_model()
        cpu_q_network.eval()
        dqn_with_preprocessor = DiscreteDqnWithPreprocessor(
            cpu_q_network, preprocessor)
        serving_module = DiscreteDqnPredictorWrapper(
            dqn_with_preprocessor, action_names=environment.ACTIONS)

        with tempfile.TemporaryDirectory() as tmpdirname:
            buf = export_module_to_buffer(serving_module)
            tmp_path = os.path.join(tmpdirname, "model")
            with open(tmp_path, "wb") as f:
                f.write(buf.getvalue())
            # Leaving the ``with`` block closes (and flushes) the file, so
            # the model is loaded only once it is fully written. Loading
            # must still happen before the temporary directory is removed.
            predictor = DiscreteDqnTorchPredictor(torch.jit.load(tmp_path))

        post_export_q_values = predictor.predict([samples.states[0]])

        for i, action in enumerate(environment.ACTIONS):
            self.assertAlmostEqual(
                float(pre_export_q_values[0][i]),
                float(post_export_q_values[0][action]),
                places=4,
            )
Пример #7
0
 def true_rewards_for_sample(self, enum_states, actions):
     """Return true rewards for enum-encoded states.

     Each entry of ``enum_states`` is converted to ``{k: 1}``, where ``k``
     is the first value of that entry cast to ``int``; the converted list
     and ``actions`` are then passed to
     ``Gridworld.true_rewards_for_sample``.
     """
     keyed_states = [
         {int(list(enum_state.values())[0]): 1} for enum_state in enum_states
     ]
     return Gridworld.true_rewards_for_sample(self, keyed_states, actions)
Пример #8
0
    def test_gridworld_generate_samples(self):
        """Check multi-step Gridworld samples for internal consistency.

        Generates ``num_samples`` samples with ``multi_steps=num_steps`` and
        verifies, up to the first sample whose first step is terminal, that
        step ``j`` of sample ``i`` matches step 0 of sample ``i + j`` and
        that non-terminal steps line up with the following sample's state,
        action and possible actions. It then converts the samples with
        ``to_single_step()`` and checks every field against the first step
        of the corresponding multi-step sample.
        """
        env = Gridworld()
        num_samples = 1000
        num_steps = 5
        samples = env.generate_samples(num_samples,
                                       epsilon=1.0,
                                       discount_factor=0.9,
                                       multi_steps=num_steps)
        for i in range(num_samples):
            # Stop at the first sample whose first step is terminal; only
            # samples before it are checked by this loop.
            if samples.terminals[i][0]:
                break
            if i < num_samples - 1:
                # Consecutive samples share an mdp id and have sequence
                # numbers that increase by one.
                self.assertEqual(samples.mdp_ids[i], samples.mdp_ids[i + 1])
                self.assertEqual(samples.sequence_numbers[i] + 1,
                                 samples.sequence_numbers[i + 1])
            for j in range(len(samples.terminals[i])):
                # Step j of sample i must equal the first step of sample i + j.
                self.assertEqual(samples.rewards[i][j],
                                 samples.rewards[i + j][0])
                self.assertDictEqual(samples.next_states[i][j],
                                     samples.next_states[i + j][0])
                self.assertEqual(samples.next_actions[i][j],
                                 samples.next_actions[i + j][0])
                self.assertEqual(samples.terminals[i][j],
                                 samples.terminals[i + j][0])
                self.assertListEqual(
                    samples.possible_next_actions[i][j],
                    samples.possible_next_actions[i + j][0],
                )
                # Terminal steps skip the successor comparison below
                # (presumably because sample i + j + 1 is not a
                # continuation of this step -- confirm).
                if samples.terminals[i][j]:
                    continue
                # The "next" fields of step j must match the current fields
                # of sample i + j + 1.
                self.assertDictEqual(samples.next_states[i][j],
                                     samples.states[i + j + 1])
                self.assertEqual(samples.next_actions[i][j],
                                 samples.actions[i + j + 1])
                self.assertListEqual(
                    samples.possible_next_actions[i][j],
                    samples.possible_actions[i + j + 1],
                )

        single_step_samples = samples.to_single_step()
        for i in range(num_samples):
            # NOTE(review): ``is True`` only matches the ``True`` singleton,
            # while the loop above relies on truthiness -- confirm that
            # ``terminals`` holds plain Python bools after to_single_step().
            if single_step_samples.terminals[i] is True:
                break
            # Fields without a step dimension are compared directly.
            self.assertEqual(single_step_samples.mdp_ids[i],
                             samples.mdp_ids[i])
            self.assertEqual(single_step_samples.sequence_numbers[i],
                             samples.sequence_numbers[i])
            self.assertDictEqual(single_step_samples.states[i],
                                 samples.states[i])
            self.assertEqual(single_step_samples.actions[i],
                             samples.actions[i])
            self.assertEqual(
                single_step_samples.action_probabilities[i],
                samples.action_probabilities[i],
            )
            # Per-step fields are compared against step 0 of the
            # multi-step sample.
            self.assertEqual(single_step_samples.rewards[i],
                             samples.rewards[i][0])
            self.assertListEqual(single_step_samples.possible_actions[i],
                                 samples.possible_actions[i])
            self.assertDictEqual(single_step_samples.next_states[i],
                                 samples.next_states[i][0])
            self.assertEqual(single_step_samples.next_actions[i],
                             samples.next_actions[i][0])
            self.assertEqual(single_step_samples.terminals[i],
                             samples.terminals[i][0])
            self.assertListEqual(
                single_step_samples.possible_next_actions[i],
                samples.possible_next_actions[i][0],
            )
Пример #9
0
 def envs_and_evaluators():
     """Return ``(environment instance, evaluator class)`` pairs.

     The list holds a fresh ``Gridworld`` paired with
     ``GridworldEvaluator`` followed by a fresh ``GridworldEnum`` paired
     with ``GridworldEnumEvaluator``.
     """
     pairs = []
     pairs.append((Gridworld(), GridworldEvaluator))
     pairs.append((GridworldEnum(), GridworldEnumEvaluator))
     return pairs
Пример #10
0
 def envs():
     """Return one-element tuples holding fresh environment instances.

     The list contains ``(Gridworld(),)`` followed by ``(GridworldEnum(),)``.
     """
     plain_env = Gridworld()
     enum_env = GridworldEnum()
     return [(plain_env,), (enum_env,)]