Example #1
# numpy and typing are needed by this snippet; dict_to_np and
# OpenAIGymEnvironment are project-specific helpers assumed to be importable
# from the surrounding module.
import numpy as np
from typing import Optional


def multi_step_sample_generator(
    gym_env: OpenAIGymEnvironment,
    num_transitions: int,
    max_steps: Optional[int],
    multi_steps: int,
    include_shorter_samples_at_start: bool,
    include_shorter_samples_at_end: bool,
):
    """
    Convert gym env multi-step sample format to MDN-RNN multi-step sample format.

    :param gym_env: The environment used to generate multi-step samples
    :param num_transitions: # of samples to return
    :param max_steps: An episode is cut off once its length exceeds max_steps
    :param multi_steps: # of steps of states and actions per sample
    :param include_shorter_samples_at_start: Whether to keep shorter-than-multi_steps
        samples generated at the beginning of an episode
    :param include_shorter_samples_at_end: Whether to keep shorter-than-multi_steps
        samples generated at the end of an episode
    """
    samples = gym_env.generate_random_samples(
        num_transitions=num_transitions,
        use_continuous_action=True,
        max_step=max_steps,
        multi_steps=multi_steps,
        include_shorter_samples_at_start=include_shorter_samples_at_start,
        include_shorter_samples_at_end=include_shorter_samples_at_end,
    )

    for j in range(num_transitions):
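        # Number of env steps actually captured by this sample (<= multi_steps)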
        sample_steps = len(samples.terminals[j])  # type: ignore
        state = dict_to_np(samples.states[j],
                           np_size=gym_env.state_dim,
                           key_offset=0)
        action = dict_to_np(samples.actions[j],
                            np_size=gym_env.action_dim,
                            key_offset=gym_env.state_dim)
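        # Decode the per-step next actions and next states from their dict
        # representation into fixed-size numpy rows.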
        next_actions = np.float32(  # type: ignore
            [
                dict_to_np(
                    samples.next_actions[j][k],
                    np_size=gym_env.action_dim,
                    key_offset=gym_env.state_dim,
                ) for k in range(sample_steps)
            ])
        next_states = np.float32(  # type: ignore
            [
                dict_to_np(samples.next_states[j][k],
                           np_size=gym_env.state_dim,
                           key_offset=0) for k in range(sample_steps)
            ])
        rewards = np.float32(samples.rewards[j])  # type: ignore
        terminals = np.float32(samples.terminals[j])  # type: ignore
        not_terminals = np.logical_not(terminals)
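        # Prepend the current state/action so rows are ordered in time; the
        # inputs drop the final row, the targets keep the last multi_steps rows.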
        ordered_states = np.vstack((state, next_states))
        ordered_actions = np.vstack((action, next_actions))
        mdnrnn_states = ordered_states[:-1]
        mdnrnn_actions = ordered_actions[:-1]
        mdnrnn_next_states = ordered_states[-multi_steps:]
        mdnrnn_next_actions = ordered_actions[-multi_steps:]

        # Pad with zeros so that all samples have an equal number of steps.
        # The general rule is to pad zeros at the end of the sequence.
        # In addition, if the sequence has only one step (i.e., it is the
        # first state of an episode), pad one zero row ahead of the
        # sequence so that embeddings are generated properly for
        # one-step samples.
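        # For example, with multi_steps=3 and sample_steps=1 (the first
        # transition of an episode), one zero row goes on top and one on the
        # bottom, so the padded sequence still has multi_steps rows.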
        num_padded_top_rows = 1 if multi_steps > 1 and sample_steps == 1 else 0
        num_padded_bottom_rows = multi_steps - sample_steps - num_padded_top_rows
        sample_steps_next = len(mdnrnn_next_states)
        num_padded_top_rows_next = 0
        num_padded_bottom_rows_next = multi_steps - sample_steps_next
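        # Yield the zero-padded state/action/reward/next-state/next-action/
        # not-terminal sequences, followed by the unpadded step counts.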
        yield (
            np.pad(
                mdnrnn_states,
                ((num_padded_top_rows, num_padded_bottom_rows), (0, 0)),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                mdnrnn_actions,
                ((num_padded_top_rows, num_padded_bottom_rows), (0, 0)),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                rewards,
                (num_padded_top_rows, num_padded_bottom_rows),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                mdnrnn_next_states,
                ((num_padded_top_rows_next, num_padded_bottom_rows_next),
                 (0, 0)),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                mdnrnn_next_actions,
                ((num_padded_top_rows_next, num_padded_bottom_rows_next),
                 (0, 0)),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                not_terminals,
                (num_padded_top_rows, num_padded_bottom_rows),
                "constant",
                constant_values=0.0,
            ),
            sample_steps,
            sample_steps_next,
        )
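
A minimal usage sketch (hedged): `env` below is a hypothetical OpenAIGymEnvironment instance created elsewhere, and the argument values are illustrative only.

# Hypothetical usage; `env` and the numbers below are placeholders.
gen = multi_step_sample_generator(
    gym_env=env,
    num_transitions=1000,
    max_steps=200,
    multi_steps=5,
    include_shorter_samples_at_start=False,
    include_shorter_samples_at_end=True,
)
for (states, actions, rewards, next_states, next_actions,
        not_terminals, steps, steps_next) in gen:
    # After padding, every sequence has exactly multi_steps rows;
    # `steps` and `steps_next` give the true lengths before padding.
    assert states.shape[0] == 5
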
Example #2
# As above, numpy plus the project-specific dict_to_np and MDNRNNMemoryPool
# helpers are assumed to be importable from the surrounding module.
def get_replay_buffer(num_episodes, seq_len, max_step, gym_env):
    num_transitions = num_episodes * max_step
    samples = gym_env.generate_random_samples(
        num_transitions=num_transitions,
        use_continuous_action=True,
        max_step=max_step,
        multi_steps=seq_len,
    )

    replay_buffer = MDNRNNMemoryPool(max_replay_memory_size=num_transitions)
    # convert RL sample format to MDN-RNN sample format
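    # An episode-ending transition has a single-entry terminal list whose flag
    # is True; record its index. The -1 seed makes the first episode start at
    # index 0 below.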
    transition_terminal_index = [-1]
    for i in range(1, len(samples.mdp_ids)):
        if samples.terminals[i][0]:
            assert len(samples.terminals[i]) == 1
            transition_terminal_index.append(i)

    for i in range(len(transition_terminal_index) - 1):
        episode_start = transition_terminal_index[i] + 1
        episode_end = transition_terminal_index[i + 1]

        for j in range(episode_start, episode_end + 1):
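            # Skip samples that do not span a full seq_len-step window.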
            if len(samples.terminals[j]) != seq_len:
                continue
            state = dict_to_np(samples.states[j],
                               np_size=gym_env.state_dim,
                               key_offset=0)
            action = dict_to_np(
                samples.actions[j],
                np_size=gym_env.action_dim,
                key_offset=gym_env.state_dim,
            )
            next_actions = np.float32([
                dict_to_np(
                    samples.next_actions[j][k],
                    np_size=gym_env.action_dim,
                    key_offset=gym_env.state_dim,
                ) for k in range(seq_len)
            ])
            next_states = np.float32([
                dict_to_np(
                    samples.next_states[j][k],
                    np_size=gym_env.state_dim,
                    key_offset=0,
                ) for k in range(seq_len)
            ])
            rewards = np.float32(samples.rewards[j])
            terminals = np.float32(samples.terminals[j])
            not_terminals = np.logical_not(terminals)
            mdnrnn_state = np.vstack((state, next_states))[:-1]
            mdnrnn_action = np.vstack((action, next_actions))[:-1]

            assert mdnrnn_state.shape == (seq_len, gym_env.state_dim)
            assert mdnrnn_action.shape == (seq_len, gym_env.action_dim)
            assert rewards.shape == (seq_len, )
            assert next_states.shape == (seq_len, gym_env.state_dim)
            assert not_terminals.shape == (seq_len, )

            replay_buffer.insert_into_memory(mdnrnn_state, mdnrnn_action,
                                             next_states, rewards,
                                             not_terminals)

    return replay_buffer
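
A similarly hedged sketch: `env` is again a hypothetical OpenAIGymEnvironment built elsewhere, and the episode/step counts are illustrative.

# Hypothetical usage; only the construction path above is exercised here.
buffer = get_replay_buffer(num_episodes=50, seq_len=5, max_step=200, gym_env=env)
# The buffer holds only full seq_len-step sequences; transitions whose
# multi-step window was cut short by an episode boundary were skipped.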