Example #1
def test_open_ai_gym_generate_samples_multi_step(self):
     env = OpenAIGymEnvironment(
         "CartPole-v0",
         epsilon=1.0,  # take random actions to collect training data
         softmax_policy=False,
         gamma=0.9,
     )
     num_samples = 1000
     num_steps = 5
     samples = env.generate_random_samples(
         num_samples, use_continuous_action=True, epsilon=1.0, multi_steps=num_steps
     )
     self._check_samples(samples, num_samples, num_steps, True)
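Example #1 delegates the verification to a `_check_samples` helper that is not shown in this snippet. Below is a minimal hedged sketch of the kind of per-step consistency such a helper could verify, modeled on the inline assertions in Example #3 further down; the sketch's name, signature, and fourth parameter are assumptions, not the repository's actual helper.

def _check_samples_sketch(samples, num_samples, num_steps, use_continuous_action):
    # Hypothetical sketch, not the real _check_samples: verify that step j of
    # transition i lines up with step 0 of transition i + j, mirroring the
    # inline assertions in Example #3 below.
    for i in range(num_samples):
        if samples.terminals[i][0]:
            break  # stop at the end of the first episode, as Example #3 does
        assert len(samples.terminals[i]) <= num_steps
        for j in range(len(samples.terminals[i])):
            assert samples.rewards[i][j] == samples.rewards[i + j][0]
            assert samples.terminals[i][j] == samples.terminals[i + j][0]
            assert samples.next_states[i][j] == samples.next_states[i + j][0]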
Example #2
def multi_step_sample_generator(
    gym_env: OpenAIGymEnvironment,
    num_transitions: int,
    max_steps: Optional[int],
    multi_steps: int,
    include_shorter_samples_at_start: bool,
    include_shorter_samples_at_end: bool,
):
    """
    Convert gym env multi-step sample format to mdn-rnn multi-step sample format

    :param gym_env: The environment used to generate multi-step samples
    :param num_transitions: # of samples to return
    :param max_steps: An episode terminates when the horizon is beyond max_steps
    :param multi_steps: # of steps of states and actions per sample
    :param include_shorter_samples_at_start: Whether to keep shorter samples
        generated at the beginning of an episode
    :param include_shorter_samples_at_end: Whether to keep shorter samples
        generated at the end of an episode
    """
    samples = gym_env.generate_random_samples(
        num_transitions=num_transitions,
        use_continuous_action=True,
        max_step=max_steps,
        multi_steps=multi_steps,
        include_shorter_samples_at_start=include_shorter_samples_at_start,
        include_shorter_samples_at_end=include_shorter_samples_at_end,
    )

    for j in range(num_transitions):
        sample_steps = len(samples.terminals[j])  # type: ignore
        state = dict_to_np(samples.states[j],
                           np_size=gym_env.state_dim,
                           key_offset=0)
        action = dict_to_np(samples.actions[j],
                            np_size=gym_env.action_dim,
                            key_offset=gym_env.state_dim)
        next_actions = np.float32(  # type: ignore
            [
                dict_to_np(
                    samples.next_actions[j][k],
                    np_size=gym_env.action_dim,
                    key_offset=gym_env.state_dim,
                ) for k in range(sample_steps)
            ])
        next_states = np.float32(  # type: ignore
            [
                dict_to_np(samples.next_states[j][k],
                           np_size=gym_env.state_dim,
                           key_offset=0) for k in range(sample_steps)
            ])
        rewards = np.float32(samples.rewards[j])  # type: ignore
        terminals = np.float32(samples.terminals[j])  # type: ignore
        not_terminals = np.logical_not(terminals)
        ordered_states = np.vstack((state, next_states))
        ordered_actions = np.vstack((action, next_actions))
        mdnrnn_states = ordered_states[:-1]
        mdnrnn_actions = ordered_actions[:-1]
        mdnrnn_next_states = ordered_states[-multi_steps:]
        mdnrnn_next_actions = ordered_actions[-multi_steps:]

        # Pad zeros so that all samples have an equal number of steps.
        # The general rule is to pad zeros at the end of sequences.
        # In addition, if the sequence has only one step (i.e., the
        # first state of an episode), pad one zero row ahead of the
        # sequence so that embeddings are generated properly for
        # one-step samples.
        num_padded_top_rows = 1 if multi_steps > 1 and sample_steps == 1 else 0
        num_padded_bottom_rows = multi_steps - sample_steps - num_padded_top_rows
        sample_steps_next = len(mdnrnn_next_states)
        num_padded_top_rows_next = 0
        num_padded_bottom_rows_next = multi_steps - sample_steps_next
        yield (
            np.pad(
                mdnrnn_states,
                ((num_padded_top_rows, num_padded_bottom_rows), (0, 0)),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                mdnrnn_actions,
                ((num_padded_top_rows, num_padded_bottom_rows), (0, 0)),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                rewards,
                ((num_padded_top_rows, num_padded_bottom_rows)),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                mdnrnn_next_states,
                ((num_padded_top_rows_next, num_padded_bottom_rows_next),
                 (0, 0)),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                mdnrnn_next_actions,
                ((num_padded_top_rows_next, num_padded_bottom_rows_next),
                 (0, 0)),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                not_terminals,
                ((num_padded_top_rows, num_padded_bottom_rows)),
                "constant",
                constant_values=0.0,
            ),
            sample_steps,
            sample_steps_next,
        )
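For reference, here is a minimal hedged usage sketch showing how the generator above might be consumed to build fixed-length training batches. The environment arguments are copied from the tests in this section; the batch-assembly code itself is an illustrative assumption, not part of the original repository.

import numpy as np

env = OpenAIGymEnvironment(
    "CartPole-v0",
    epsilon=1.0,  # take random actions, as in the tests in this section
    softmax_policy=False,
    gamma=0.9,
)
generator = multi_step_sample_generator(
    gym_env=env,
    num_transitions=1000,
    max_steps=None,
    multi_steps=5,
    include_shorter_samples_at_start=True,
    include_shorter_samples_at_end=True,
)
states, actions, rewards = [], [], []
for (mdnrnn_states, mdnrnn_actions, mdnrnn_rewards, _next_states,
     _next_actions, _not_terminals, _steps, _steps_next) in generator:
    # Every yielded array is already zero-padded to `multi_steps` rows,
    # so the samples can be stacked directly into batch tensors.
    states.append(mdnrnn_states)
    actions.append(mdnrnn_actions)
    rewards.append(mdnrnn_rewards)
batch_states = np.stack(states)    # (num_transitions, multi_steps, state_dim)
batch_actions = np.stack(actions)  # (num_transitions, multi_steps, action_dim)
batch_rewards = np.stack(rewards)  # (num_transitions, multi_steps)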
Example #3
    def test_open_ai_gym_generate_samples_multi_step(self):
        env = OpenAIGymEnvironment(
            "CartPole-v0",
            epsilon=1.0,  # take random actions to collect training data
            softmax_policy=False,
            gamma=0.9,
        )
        num_samples = 1000
        num_steps = 5
        samples = env.generate_random_samples(num_samples,
                                              use_continuous_action=True,
                                              epsilon=1.0,
                                              multi_steps=num_steps)
        for i in range(num_samples):
            if samples.terminals[i][0]:
                break
            if i < num_samples - 1:
                self.assertEqual(samples.mdp_ids[i], samples.mdp_ids[i + 1])
                self.assertEqual(samples.sequence_numbers[i] + 1,
                                 samples.sequence_numbers[i + 1])
            for j in range(len(samples.terminals[i])):
                self.assertEqual(samples.rewards[i][j],
                                 samples.rewards[i + j][0])
                self.assertDictEqual(samples.next_states[i][j],
                                     samples.next_states[i + j][0])
                self.assertDictEqual(samples.next_actions[i][j],
                                     samples.next_actions[i + j][0])
                self.assertEqual(samples.terminals[i][j],
                                 samples.terminals[i + j][0])
                self.assertListEqual(
                    samples.possible_next_actions[i][j],
                    samples.possible_next_actions[i + j][0],
                )
                if samples.terminals[i][j]:
                    continue
                self.assertDictEqual(samples.next_states[i][j],
                                     samples.states[i + j + 1])
                self.assertDictEqual(samples.next_actions[i][j],
                                     samples.actions[i + j + 1])
                self.assertListEqual(
                    samples.possible_next_actions[i][j],
                    samples.possible_actions[i + j + 1],
                )

        single_step_samples = samples.to_single_step()
        for i in range(num_samples):
            if single_step_samples.terminals[i] is True:
                break
            self.assertEqual(single_step_samples.mdp_ids[i],
                             samples.mdp_ids[i])
            self.assertEqual(single_step_samples.sequence_numbers[i],
                             samples.sequence_numbers[i])
            self.assertDictEqual(single_step_samples.states[i],
                                 samples.states[i])
            self.assertDictEqual(single_step_samples.actions[i],
                                 samples.actions[i])
            self.assertEqual(
                single_step_samples.action_probabilities[i],
                samples.action_probabilities[i],
            )
            self.assertEqual(single_step_samples.rewards[i],
                             samples.rewards[i][0])
            self.assertListEqual(single_step_samples.possible_actions[i],
                                 samples.possible_actions[i])
            self.assertDictEqual(single_step_samples.next_states[i],
                                 samples.next_states[i][0])
            self.assertDictEqual(single_step_samples.next_actions[i],
                                 samples.next_actions[i][0])
            self.assertEqual(single_step_samples.terminals[i],
                             samples.terminals[i][0])
            self.assertListEqual(
                single_step_samples.possible_next_actions[i],
                samples.possible_next_actions[i][0],
            )
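The second loop above effectively specifies what `to_single_step()` must do: per-transition fields are kept as-is, while every per-step field collapses to its first step. Below is a hedged sketch of that reduction; the dict-based return value is an illustrative assumption, and the library's actual implementation returns its own samples type and may differ.

def to_single_step_sketch(samples):
    # Hypothetical illustration, not the library's implementation: keep only
    # step 0 of each per-step field, mirroring the equalities asserted above.
    return {
        "mdp_ids": samples.mdp_ids,
        "sequence_numbers": samples.sequence_numbers,
        "states": samples.states,
        "actions": samples.actions,
        "action_probabilities": samples.action_probabilities,
        "possible_actions": samples.possible_actions,
        "rewards": [r[0] for r in samples.rewards],
        "next_states": [s[0] for s in samples.next_states],
        "next_actions": [a[0] for a in samples.next_actions],
        "terminals": [t[0] for t in samples.terminals],
        "possible_next_actions": [p[0] for p in samples.possible_next_actions],
    }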