Beispiel #1
0
    def test_create_df_from_replay_buffer(self):
        """Verify an unshuffled replay-buffer DataFrame against TestEnv's log.

        Builds a 1000-row DataFrame (``shuffle_df=False``) with
        ``create_df_from_replay_buffer`` from a ``TestEnv``-wrapped
        ``MiniGrid-Empty-5x5-v0`` environment. Each row is then compared field
        by field with the matching entry of ``TestEnv.sart``. Those tuples are
        read here as (mdp_id, sequence_number, state, action, reward, terminal).
        """
        base_env = Gym(env_name="MiniGrid-Empty-5x5-v0")
        n_features = base_env.observation_space.shape[0]
        # The TestEnv wrapper exposes the per-step log (``sart``) checked below.
        logged_env = TestEnv(base_env)
        target_size = 1000

        frame = create_df_from_replay_buffer(
            env=logged_env,
            problem_domain=ProblemDomain.DISCRETE_ACTION,
            desired_size=target_size,
            multi_steps=None,
            ds="2021-09-16",
            shuffle_df=False,
        )
        self.assertEqual(len(frame), target_size)

        sparse_to_dense = PythonSparseToDenseProcessor(list(range(n_features)))

        def densify(sparse_features):
            # Item [0] of process()'s result is the dense value tensor. The
            # second [0] selects the single row passed in.
            return sparse_to_dense.process([sparse_features])[0][0].numpy()

        for row_id, record in frame.iterrows():
            logged = logged_env.sart[row_id]

            self.assertEqual(record["mdp_id"], str(logged[0]))
            self.assertEqual(record["sequence_number"], logged[1])
            npt.assert_array_equal(densify(record["state_features"]), logged[2])
            self.assertEqual(record["action"], str(logged[3]))

            # The DataFrame encodes a terminal transition as an empty
            # next_action. This is checked against the logged terminal flag.
            is_terminal = record["next_action"] == ""
            self.assertEqual(is_terminal, logged[5])

            if is_terminal:
                # NOTE(review): presumably the log entry right after a terminal
                # step has no DataFrame row. Dropping it keeps ``sart`` indices
                # aligned with the DataFrame index. Confirm against TestEnv.
                del logged_env.sart[row_id + 1]
                continue

            npt.assert_allclose(float(record["reward"]), float(logged[4]))

            next_dense = densify(record["next_state_features"])
            following = logged_env.sart[row_id + 1]
            npt.assert_array_equal(next_dense, following[2])
            self.assertEqual(record["next_action"], str(following[3]))
Beispiel #2
0
 def test_int_key_sparse_to_dense(self):
     """Densify integer-keyed sparse data with set_missing_value_to_zero=False.

     The processor output is a (values, presence) pair. Each part is compared
     with the ``*_missing`` fixtures on the test instance.
     """
     dense_processor = PythonSparseToDenseProcessor(
         self.sorted_features,
         set_missing_value_to_zero=False,
     )
     output = dense_processor.process(self.int_keyed_sparse_data)
     dense_values, presence_mask = output
     assert torch.allclose(dense_values, self.expected_value_missing)
     # Every presence flag must match the expected mask exactly.
     assert torch.all(presence_mask == self.expected_presence_missing)