from typing import Optional

import numpy as np


def multi_step_sample_generator(
    gym_env: OpenAIGymEnvironment,
    num_transitions: int,
    max_steps: Optional[int],
    multi_steps: int,
    include_shorter_samples_at_start: bool,
    include_shorter_samples_at_end: bool,
):
    """
    Convert gym env multi-step sample format to mdn-rnn multi-step sample format

    :param gym_env: the environment used to generate multi-step samples
    :param num_transitions: # of samples to return
    :param max_steps: an episode terminates when the horizon is beyond max_steps
    :param multi_steps: # of steps of states and actions per sample
    :param include_shorter_samples_at_start: whether to keep samples with fewer
        steps that are generated at the beginning of an episode
    :param include_shorter_samples_at_end: whether to keep samples with fewer
        steps that are generated at the end of an episode
    """
    samples = gym_env.generate_random_samples(
        num_transitions=num_transitions,
        use_continuous_action=True,
        max_step=max_steps,
        multi_steps=multi_steps,
        include_shorter_samples_at_start=include_shorter_samples_at_start,
        include_shorter_samples_at_end=include_shorter_samples_at_end,
    )

    for j in range(num_transitions):
        sample_steps = len(samples.terminals[j])  # type: ignore
        state = dict_to_np(
            samples.states[j], np_size=gym_env.state_dim, key_offset=0
        )
        action = dict_to_np(
            samples.actions[j],
            np_size=gym_env.action_dim,
            key_offset=gym_env.state_dim,
        )
        next_actions = np.float32(  # type: ignore
            [
                dict_to_np(
                    samples.next_actions[j][k],
                    np_size=gym_env.action_dim,
                    key_offset=gym_env.state_dim,
                )
                for k in range(sample_steps)
            ]
        )
        next_states = np.float32(  # type: ignore
            [
                dict_to_np(
                    samples.next_states[j][k],
                    np_size=gym_env.state_dim,
                    key_offset=0,
                )
                for k in range(sample_steps)
            ]
        )
        rewards = np.float32(samples.rewards[j])  # type: ignore
        terminals = np.float32(samples.terminals[j])  # type: ignore
        not_terminals = np.logical_not(terminals)
        ordered_states = np.vstack((state, next_states))
        ordered_actions = np.vstack((action, next_actions))
        mdnrnn_states = ordered_states[:-1]
        mdnrnn_actions = ordered_actions[:-1]
        mdnrnn_next_states = ordered_states[-multi_steps:]
        mdnrnn_next_actions = ordered_actions[-multi_steps:]

        # Pad zeros so that all samples have equal steps. The general rule
        # is to pad zeros at the end of sequences. In addition, if the
        # sequence only has one step (i.e., the first state of an episode),
        # pad one zero row ahead of the sequence so that embeddings are
        # generated properly for one-step samples.
        num_padded_top_rows = 1 if multi_steps > 1 and sample_steps == 1 else 0
        num_padded_bottom_rows = multi_steps - sample_steps - num_padded_top_rows
        sample_steps_next = len(mdnrnn_next_states)
        num_padded_top_rows_next = 0
        num_padded_bottom_rows_next = multi_steps - sample_steps_next

        yield (
            np.pad(
                mdnrnn_states,
                ((num_padded_top_rows, num_padded_bottom_rows), (0, 0)),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                mdnrnn_actions,
                ((num_padded_top_rows, num_padded_bottom_rows), (0, 0)),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                rewards,
                (num_padded_top_rows, num_padded_bottom_rows),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                mdnrnn_next_states,
                ((num_padded_top_rows_next, num_padded_bottom_rows_next), (0, 0)),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                mdnrnn_next_actions,
                ((num_padded_top_rows_next, num_padded_bottom_rows_next), (0, 0)),
                "constant",
                constant_values=0.0,
            ),
            np.pad(
                not_terminals,
                (num_padded_top_rows, num_padded_bottom_rows),
                "constant",
                constant_values=0.0,
            ),
            sample_steps,
            sample_steps_next,
        )
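
# A minimal self-contained sketch of the padding rule used above, for
# illustration only; `_pad_steps` is a hypothetical helper, not part of this
# module. For multi_steps=3: a full 3-step sample needs no padding, a 2-step
# tail sample gets one zero row at the bottom, and a 1-step sample gets one
# zero row on top and one at the bottom.
def _pad_steps(rows, multi_steps):
    sample_steps = len(rows)
    top = 1 if multi_steps > 1 and sample_steps == 1 else 0
    bottom = multi_steps - sample_steps - top
    return np.pad(
        np.asarray(rows, dtype=np.float32),
        ((top, bottom), (0, 0)),
        "constant",
        constant_values=0.0,
    )


# e.g. _pad_steps([[1.0, 1.0]], multi_steps=3)
# -> [[0., 0.], [1., 1.], [0., 0.]]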
def get_replay_buffer(num_episodes, seq_len, max_step, gym_env):
    num_transitions = num_episodes * max_step
    samples = gym_env.generate_random_samples(
        num_transitions=num_transitions,
        use_continuous_action=True,
        max_step=max_step,
        multi_steps=seq_len,
    )
    replay_buffer = MDNRNNMemoryPool(max_replay_memory_size=num_transitions)

    # Convert RL sample format to MDN-RNN sample format. The last transition
    # of an episode is the one whose multi-step sequence has length 1 and
    # starts with a terminal flag; collect those indices as episode markers.
    transition_terminal_index = [-1]
    for i in range(1, len(samples.mdp_ids)):
        # Use a truthiness check rather than `is True`, which fails for
        # numpy bools
        if samples.terminals[i][0]:
            assert len(samples.terminals[i]) == 1
            transition_terminal_index.append(i)

    for i in range(len(transition_terminal_index) - 1):
        episode_start = transition_terminal_index[i] + 1
        episode_end = transition_terminal_index[i + 1]
        for j in range(episode_start, episode_end + 1):
            # Only keep samples that span exactly seq_len steps
            if len(samples.terminals[j]) != seq_len:
                continue
            state = dict_to_np(
                samples.states[j], np_size=gym_env.state_dim, key_offset=0
            )
            action = dict_to_np(
                samples.actions[j],
                np_size=gym_env.action_dim,
                key_offset=gym_env.state_dim,
            )
            next_actions = np.float32(
                [
                    dict_to_np(
                        samples.next_actions[j][k],
                        np_size=gym_env.action_dim,
                        key_offset=gym_env.state_dim,
                    )
                    for k in range(seq_len)
                ]
            )
            next_states = np.float32(
                [
                    dict_to_np(
                        samples.next_states[j][k],
                        np_size=gym_env.state_dim,
                        key_offset=0,
                    )
                    for k in range(seq_len)
                ]
            )
            rewards = np.float32(samples.rewards[j])
            terminals = np.float32(samples.terminals[j])
            not_terminals = np.logical_not(terminals)
            # Prepend the current state/action and drop the last next-step
            # entry so inputs align one step ahead of the next_states targets
            mdnrnn_state = np.vstack((state, next_states))[:-1]
            mdnrnn_action = np.vstack((action, next_actions))[:-1]

            assert mdnrnn_state.shape == (seq_len, gym_env.state_dim)
            assert mdnrnn_action.shape == (seq_len, gym_env.action_dim)
            assert rewards.shape == (seq_len,)
            assert next_states.shape == (seq_len, gym_env.state_dim)
            assert not_terminals.shape == (seq_len,)

            replay_buffer.insert_into_memory(
                mdnrnn_state, mdnrnn_action, next_states, rewards, not_terminals
            )

    return replay_buffer
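
# A standalone sketch of the episode-boundary bookkeeping above, for
# illustration only; `_episode_ranges` is a hypothetical helper, not part of
# this module. Given the first terminal flag of each multi-step transition,
# it recovers the inclusive (start, end) index range of each episode exactly
# as get_replay_buffer does with transition_terminal_index.
def _episode_ranges(first_step_terminals):
    terminal_index = [-1]
    for i, terminal in enumerate(first_step_terminals):
        if terminal and i > 0:
            terminal_index.append(i)
    return [
        (terminal_index[i] + 1, terminal_index[i + 1])
        for i in range(len(terminal_index) - 1)
    ]


# e.g. flags for two 3-transition episodes, where only the last transition
# of each episode starts with a terminal flag:
# _episode_ranges([False, False, True, False, False, True])
# -> [(0, 2), (3, 5)]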