def get_replay_buffer(
    num_episodes: int, seq_len: int, max_step: int, gym_env: OpenAIGymEnvironment
) -> MDNRNNMemoryPool:
    num_transitions = num_episodes * max_step
    replay_buffer = MDNRNNMemoryPool(max_replay_memory_size=num_transitions)
    for (
        mdnrnn_state,
        mdnrnn_action,
        rewards,
        next_states,
        _,
        not_terminals,
        _,
        _,
    ) in multi_step_sample_generator(
        gym_env,
        num_transitions=num_transitions,
        max_steps=max_step,
        multi_steps=seq_len,
        include_shorter_samples_at_start=False,
        include_shorter_samples_at_end=False,
    ):
        mdnrnn_state, mdnrnn_action, next_states, rewards, not_terminals = (
            torch.tensor(mdnrnn_state),
            torch.tensor(mdnrnn_action),
            torch.tensor(next_states),
            torch.tensor(rewards),
            torch.tensor(not_terminals),
        )
        replay_buffer.insert_into_memory(
            mdnrnn_state, mdnrnn_action, next_states, rewards, not_terminals
        )
    return replay_buffer
def get_replay_buffer(
    num_episodes: int,
    seq_len: int,
    max_step: Optional[int],
    gym_env: OpenAIGymEnvironment,
):
    num_transitions = num_episodes * max_step  # type: ignore
    replay_buffer = MDNRNNMemoryPool(max_replay_memory_size=num_transitions)
    for (
        mdnrnn_state,
        mdnrnn_action,
        rewards,
        next_states,
        _,
        not_terminals,
        _,
        _,
    ) in multi_step_sample_generator(
        gym_env,
        num_transitions=num_transitions,
        max_steps=max_step,
        multi_steps=seq_len,
        ignore_shorter_samples_at_start=True,
        ignore_shorter_samples_at_end=True,
    ):
        replay_buffer.insert_into_memory(
            mdnrnn_state, mdnrnn_action, next_states, rewards, not_terminals
        )
    return replay_buffer
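# Minimal usage sketch (not part of the original tests) of either
# get_replay_buffer variant above. It assumes `gym_env` is an already
# constructed OpenAIGymEnvironment; the episode count, sequence length,
# step cap, and batch size are illustrative. `sample_memories` is used the
# same way in the training test below.
def _sketch_build_and_sample(gym_env: OpenAIGymEnvironment):
    replay_buffer = get_replay_buffer(
        num_episodes=100, seq_len=5, max_step=200, gym_env=gym_env
    )
    # each stored entry is one length-`seq_len` slice of a trajectory:
    # (state, action, next_state, reward, not_terminal)
    return replay_buffer.sample_memories(16)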
def test_mdnrnn_simulate_world(self):
    num_epochs = 300
    num_episodes = 400
    batch_size = 200
    action_dim = 2
    seq_len = 5
    state_dim = 2
    simulated_num_gaussians = 2
    mdnrnn_num_gaussians = 2
    simulated_num_hidden_layers = 1
    simulated_num_hiddens = 3
    mdnrnn_num_hidden_layers = 1
    mdnrnn_num_hiddens = 10
    adam_lr = 0.01

    replay_buffer = MDNRNNMemoryPool(max_replay_memory_size=num_episodes)
    # ground-truth world model used to generate training trajectories
    swm = SimulatedWorldModel(
        action_dim=action_dim,
        state_dim=state_dim,
        num_gaussians=simulated_num_gaussians,
        lstm_num_hidden_layers=simulated_num_hidden_layers,
        lstm_num_hiddens=simulated_num_hiddens,
    )

    # roll out episodes from the simulated world and store them in the buffer
    possible_actions = torch.eye(action_dim)
    for _ in range(num_episodes):
        cur_state_mem = np.zeros((seq_len, state_dim))
        next_state_mem = np.zeros((seq_len, state_dim))
        action_mem = np.zeros((seq_len, action_dim))
        reward_mem = np.zeros(seq_len)
        not_terminal_mem = np.zeros(seq_len)
        next_mus_mem = np.zeros((seq_len, simulated_num_gaussians, state_dim))
        swm.init_hidden(batch_size=1)
        next_state = torch.randn((1, 1, state_dim))
        for s in range(seq_len):
            cur_state = next_state
            action = possible_actions[np.random.randint(action_dim)].view(
                1, 1, action_dim
            )
            next_mus, reward = swm(action, cur_state)

            not_terminal = 1
            if s == seq_len - 1:
                not_terminal = 0

            # randomly draw the next state from the simulated mixture components
            next_pi = torch.ones(simulated_num_gaussians) / simulated_num_gaussians
            index = Categorical(next_pi).sample((1,)).long().item()
            next_state = next_mus[0, 0, index].view(1, 1, state_dim)

            cur_state_mem[s] = cur_state.detach().numpy()
            action_mem[s] = action.numpy()
            reward_mem[s] = reward.detach().numpy()
            not_terminal_mem[s] = not_terminal
            next_state_mem[s] = next_state.detach().numpy()
            next_mus_mem[s] = next_mus.detach().numpy()

        replay_buffer.insert_into_memory(
            cur_state_mem, action_mem, next_state_mem, reward_mem, not_terminal_mem
        )

    # train an MDN-RNN memory network on the generated trajectories
    num_batch = num_episodes // batch_size
    mdnrnn_params = MDNRNNParameters(
        hidden_size=mdnrnn_num_hiddens,
        num_hidden_layers=mdnrnn_num_hidden_layers,
        minibatch_size=batch_size,
        learning_rate=adam_lr,
        num_gaussians=mdnrnn_num_gaussians,
    )
    mdnrnn_net = MemoryNetwork(
        state_dim=state_dim,
        action_dim=action_dim,
        num_hiddens=mdnrnn_params.hidden_size,
        num_hidden_layers=mdnrnn_params.num_hidden_layers,
        num_gaussians=mdnrnn_params.num_gaussians,
    )
    trainer = MDNRNNTrainer(
        mdnrnn_network=mdnrnn_net, params=mdnrnn_params, cum_loss_hist=num_batch
    )

    for e in range(num_epochs):
        for i in range(num_batch):
            training_batch = replay_buffer.sample_memories(batch_size)
            losses = trainer.train(training_batch)
            logger.info(
                "{}-th epoch, {}-th minibatch: \n"
                "loss={}, bce={}, gmm={}, mse={} \n"
                "cum loss={}, cum bce={}, cum gmm={}, cum mse={}\n".format(
                    e,
                    i,
                    losses["loss"],
                    losses["bce"],
                    losses["gmm"],
                    losses["mse"],
                    np.mean(trainer.cum_loss),
                    np.mean(trainer.cum_bce),
                    np.mean(trainer.cum_gmm),
                    np.mean(trainer.cum_mse),
                )
            )

            # stop early once all running-average losses are low enough
            if (
                np.mean(trainer.cum_loss) < 0
                and np.mean(trainer.cum_gmm) < -3.0
                and np.mean(trainer.cum_bce) < 0.6
                and np.mean(trainer.cum_mse) < 0.2
            ):
                return

    assert False, "losses not reduced significantly during training"
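# Why the `cum_gmm` threshold above is negative: the GMM term is a negative
# log-likelihood of a continuous density, and densities can exceed 1, so the
# loss is not bounded below by 0. The sketch below (illustrative only, not the
# trainer's actual implementation) shows the standard mixture-of-Gaussians NLL
# that such a term typically computes.
def _sketch_gmm_nll(x, pi, mu, sigma):
    """x: (batch, dim); pi: (batch, k); mu, sigma: (batch, k, dim)."""
    from torch.distributions import Normal

    # per-component log density, summed over state dimensions -> (batch, k)
    log_prob = Normal(mu, sigma).log_prob(x.unsqueeze(1)).sum(dim=-1)
    # log of the mixture density -> (batch,)
    log_mix = torch.logsumexp(torch.log(pi) + log_prob, dim=1)
    return -log_mix.mean()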
def get_replay_buffer(num_episodes, seq_len, max_step, gym_env):
    num_transitions = num_episodes * max_step
    samples = gym_env.generate_random_samples(
        num_transitions=num_transitions,
        use_continuous_action=True,
        max_step=max_step,
        multi_steps=seq_len,
    )

    replay_buffer = MDNRNNMemoryPool(max_replay_memory_size=num_transitions)
    # convert RL sample format to MDN-RNN sample format
    # mark episode boundaries: index i is a boundary when its terminal flag is set
    transition_terminal_index = [-1]
    for i in range(1, len(samples.mdp_ids)):
        if samples.terminals[i][0] is True:
            assert len(samples.terminals[i]) == 1
            transition_terminal_index.append(i)

    for i in range(len(transition_terminal_index) - 1):
        episode_start = transition_terminal_index[i] + 1
        episode_end = transition_terminal_index[i + 1]

        for j in range(episode_start, episode_end + 1):
            if len(samples.terminals[j]) != seq_len:
                continue
            state = dict_to_np(
                samples.states[j], np_size=gym_env.state_dim, key_offset=0
            )
            action = dict_to_np(
                samples.actions[j],
                np_size=gym_env.action_dim,
                key_offset=gym_env.state_dim,
            )
            next_actions = np.float32(
                [
                    dict_to_np(
                        samples.next_actions[j][k],
                        np_size=gym_env.action_dim,
                        key_offset=gym_env.state_dim,
                    )
                    for k in range(seq_len)
                ]
            )
            next_states = np.float32(
                [
                    dict_to_np(
                        samples.next_states[j][k],
                        np_size=gym_env.state_dim,
                        key_offset=0,
                    )
                    for k in range(seq_len)
                ]
            )
            rewards = np.float32(samples.rewards[j])
            terminals = np.float32(samples.terminals[j])
            not_terminals = np.logical_not(terminals)
            # stack the current step on top of the next steps and drop the last
            # row, yielding seq_len-long state/action input sequences
            mdnrnn_state = np.vstack((state, next_states))[:-1]
            mdnrnn_action = np.vstack((action, next_actions))[:-1]

            assert mdnrnn_state.shape == (seq_len, gym_env.state_dim)
            assert mdnrnn_action.shape == (seq_len, gym_env.action_dim)
            assert rewards.shape == (seq_len,)
            assert next_states.shape == (seq_len, gym_env.state_dim)
            assert not_terminals.shape == (seq_len,)

            replay_buffer.insert_into_memory(
                mdnrnn_state, mdnrnn_action, next_states, rewards, not_terminals
            )

    return replay_buffer
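# `dict_to_np` is defined elsewhere in the test utilities. Based only on how it
# is called above (sparse feature dicts keyed by integer feature id, with action
# ids offset by `state_dim`), a plausible sketch of its behavior is below; the
# real helper may differ in details such as dtype handling.
def _sketch_dict_to_np(feature_dict, np_size, key_offset):
    arr = np.zeros(np_size, dtype=np.float32)
    for key, value in feature_dict.items():
        # shift feature ids into the range [0, np_size) before writing
        arr[int(key) - key_offset] = value
    return arr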