Example #1
def get_replay_buffer(num_episodes: int, seq_len: int, max_step: int,
                      gym_env: OpenAIGymEnvironment) -> MDNRNNMemoryPool:
    num_transitions = num_episodes * max_step
    replay_buffer = MDNRNNMemoryPool(max_replay_memory_size=num_transitions)
    for (
            mdnrnn_state,
            mdnrnn_action,
            rewards,
            next_states,
            _,
            not_terminals,
            _,
            _,
    ) in multi_step_sample_generator(
            gym_env,
            num_transitions=num_transitions,
            max_steps=max_step,
            multi_steps=seq_len,
            include_shorter_samples_at_start=False,
            include_shorter_samples_at_end=False,
    ):
        # Convert the sampled transitions to torch tensors before inserting
        # them into the memory pool.
        mdnrnn_state, mdnrnn_action, next_states, rewards, not_terminals = (
            torch.tensor(mdnrnn_state),
            torch.tensor(mdnrnn_action),
            torch.tensor(next_states),
            torch.tensor(rewards),
            torch.tensor(not_terminals),
        )
        replay_buffer.insert_into_memory(mdnrnn_state, mdnrnn_action,
                                         next_states, rewards, not_terminals)

    return replay_buffer
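
A minimal usage sketch for the helper above (not part of the original example): it assumes gym_env is an already-constructed OpenAIGymEnvironment and uses arbitrary episode/step counts; sample_memories is the same call used in Example #3.

# Hypothetical usage, assuming gym_env is an existing OpenAIGymEnvironment.
replay_buffer = get_replay_buffer(
    num_episodes=400, seq_len=5, max_step=200, gym_env=gym_env
)
# Minibatches come out in the format consumed by MDNRNNTrainer (see Example #3).
training_batch = replay_buffer.sample_memories(200)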
Example #2
def get_replay_buffer(
    num_episodes: int,
    seq_len: int,
    max_step: Optional[int],
    gym_env: OpenAIGymEnvironment,
):
    num_transitions = num_episodes * max_step  # type: ignore
    replay_buffer = MDNRNNMemoryPool(max_replay_memory_size=num_transitions)
    for (
            mdnrnn_state,
            mdnrnn_action,
            rewards,
            next_states,
            _,
            not_terminals,
            _,
            _,
    ) in multi_step_sample_generator(
            gym_env,
            num_transitions=num_transitions,
            max_steps=max_step,
            multi_steps=seq_len,
            ignore_shorter_samples_at_start=True,
            ignore_shorter_samples_at_end=True,
    ):
        replay_buffer.insert_into_memory(mdnrnn_state, mdnrnn_action,
                                         next_states, rewards, not_terminals)

    return replay_buffer
Example #3
    def test_mdnrnn_simulate_world(self):
        num_epochs = 300
        num_episodes = 400
        batch_size = 200
        action_dim = 2
        seq_len = 5
        state_dim = 2
        simulated_num_gaussians = 2
        mdrnn_num_gaussians = 2
        simulated_num_hidden_layers = 1
        simulated_num_hiddens = 3
        mdnrnn_num_hidden_layers = 1
        mdnrnn_num_hiddens = 10
        adam_lr = 0.01

        replay_buffer = MDNRNNMemoryPool(max_replay_memory_size=num_episodes)
        swm = SimulatedWorldModel(
            action_dim=action_dim,
            state_dim=state_dim,
            num_gaussians=simulated_num_gaussians,
            lstm_num_hidden_layers=simulated_num_hidden_layers,
            lstm_num_hiddens=simulated_num_hiddens,
        )

        possible_actions = torch.eye(action_dim)
        # Roll out episodes from the simulated world model and store them in
        # the replay buffer as ground-truth training data for the MDN-RNN.
        for _ in range(num_episodes):
            cur_state_mem = np.zeros((seq_len, state_dim))
            next_state_mem = np.zeros((seq_len, state_dim))
            action_mem = np.zeros((seq_len, action_dim))
            reward_mem = np.zeros(seq_len)
            not_terminal_mem = np.zeros(seq_len)
            next_mus_mem = np.zeros(
                (seq_len, simulated_num_gaussians, state_dim))

            swm.init_hidden(batch_size=1)
            next_state = torch.randn((1, 1, state_dim))
            for s in range(seq_len):
                cur_state = next_state
                action = possible_actions[np.random.randint(action_dim)].view(
                    1, 1, action_dim)
                next_mus, reward = swm(action, cur_state)

                not_terminal = 1
                if s == seq_len - 1:
                    not_terminal = 0

                # randomly draw for next state
                next_pi = torch.ones(
                    simulated_num_gaussians) / simulated_num_gaussians
                index = Categorical(next_pi).sample((1, )).long().item()
                next_state = next_mus[0, 0, index].view(1, 1, state_dim)

                cur_state_mem[s] = cur_state.detach().numpy()
                action_mem[s] = action.numpy()
                reward_mem[s] = reward.detach().numpy()
                not_terminal_mem[s] = not_terminal
                next_state_mem[s] = next_state.detach().numpy()
                next_mus_mem[s] = next_mus.detach().numpy()

            replay_buffer.insert_into_memory(cur_state_mem, action_mem,
                                             next_state_mem, reward_mem,
                                             not_terminal_mem)

        num_batch = num_episodes // batch_size
        mdnrnn_params = MDNRNNParameters(
            hidden_size=mdnrnn_num_hiddens,
            num_hidden_layers=mdnrnn_num_hidden_layers,
            minibatch_size=batch_size,
            learning_rate=adam_lr,
            num_gaussians=mdrnn_num_gaussians,
        )
        mdnrnn_net = MemoryNetwork(
            state_dim=state_dim,
            action_dim=action_dim,
            num_hiddens=mdnrnn_params.hidden_size,
            num_hidden_layers=mdnrnn_params.num_hidden_layers,
            num_gaussians=mdnrnn_params.num_gaussians,
        )
        trainer = MDNRNNTrainer(mdnrnn_network=mdnrnn_net,
                                params=mdnrnn_params,
                                cum_loss_hist=num_batch)

        # Train the MDN-RNN on the generated data; stop early once the
        # cumulative losses fall below the thresholds checked below.
        for e in range(num_epochs):
            for i in range(num_batch):
                training_batch = replay_buffer.sample_memories(batch_size)
                losses = trainer.train(training_batch)
                logger.info(
                    "{}-th epoch, {}-th minibatch: \n"
                    "loss={}, bce={}, gmm={}, mse={} \n"
                    "cum loss={}, cum bce={}, cum gmm={}, cum mse={}\n".format(
                        e,
                        i,
                        losses["loss"],
                        losses["bce"],
                        losses["gmm"],
                        losses["mse"],
                        np.mean(trainer.cum_loss),
                        np.mean(trainer.cum_bce),
                        np.mean(trainer.cum_gmm),
                        np.mean(trainer.cum_mse),
                    ))

                if (np.mean(trainer.cum_loss) < 0
                        and np.mean(trainer.cum_gmm) < -3.0
                        and np.mean(trainer.cum_bce) < 0.6
                        and np.mean(trainer.cum_mse) < 0.2):
                    return

        assert False, "losses not reduced significantly during training"
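
The trainer above reports bce, gmm and mse loss components. Assuming the conventional MDN-RNN formulation (GMM negative log-likelihood on the next state, binary cross-entropy on the terminal signal, MSE on the reward), the gmm term can be sketched as below; the function name, shapes and variable names are illustrative assumptions, not the internals of MDNRNNTrainer.

import torch

def gmm_nll(next_state, mus, sigmas, logpi):
    # Negative log-likelihood of next_state under a diagonal Gaussian mixture.
    # Assumed shapes: next_state (seq_len, batch, state_dim),
    # mus/sigmas (seq_len, batch, num_gaussians, state_dim),
    # logpi (seq_len, batch, num_gaussians) holding log mixture weights.
    target = next_state.unsqueeze(-2)  # broadcast over the mixture dimension
    normal = torch.distributions.Normal(mus, sigmas)
    log_probs = logpi + normal.log_prob(target).sum(dim=-1)
    # Log-sum-exp over components, averaged over sequence and batch.
    return -torch.logsumexp(log_probs, dim=-1).mean()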
Example #4
def get_replay_buffer(num_episodes, seq_len, max_step, gym_env):
    num_transitions = num_episodes * max_step
    samples = gym_env.generate_random_samples(
        num_transitions=num_transitions,
        use_continuous_action=True,
        max_step=max_step,
        multi_steps=seq_len,
    )

    replay_buffer = MDNRNNMemoryPool(max_replay_memory_size=num_transitions)
    # convert RL sample format to MDN-RNN sample format
    # Find episode boundaries: indices whose (single) terminal flag is True.
    # Seeding with -1 lets the first episode start at index 0 below.
    transition_terminal_index = [-1]
    for i in range(1, len(samples.mdp_ids)):
        if samples.terminals[i][0] is True:
            assert len(samples.terminals[i]) == 1
            transition_terminal_index.append(i)

    for i in range(len(transition_terminal_index) - 1):
        episode_start = transition_terminal_index[i] + 1
        episode_end = transition_terminal_index[i + 1]

        for j in range(episode_start, episode_end + 1):
            if len(samples.terminals[j]) != seq_len:
                continue
            state = dict_to_np(samples.states[j],
                               np_size=gym_env.state_dim,
                               key_offset=0)
            action = dict_to_np(
                samples.actions[j],
                np_size=gym_env.action_dim,
                key_offset=gym_env.state_dim,
            )
            next_actions = np.float32([
                dict_to_np(
                    samples.next_actions[j][k],
                    np_size=gym_env.action_dim,
                    key_offset=gym_env.state_dim,
                ) for k in range(seq_len)
            ])
            next_states = np.float32([
                dict_to_np(
                    samples.next_states[j][k],
                    np_size=gym_env.state_dim,
                    key_offset=0,
                ) for k in range(seq_len)
            ])
            rewards = np.float32(samples.rewards[j])
            terminals = np.float32(samples.terminals[j])
            not_terminals = np.logical_not(terminals)
            # Stack the first state/action with the multi-step next values and
            # drop the last entry to get seq_len-long input sequences.
            mdnrnn_state = np.vstack((state, next_states))[:-1]
            mdnrnn_action = np.vstack((action, next_actions))[:-1]

            assert mdnrnn_state.shape == (seq_len, gym_env.state_dim)
            assert mdnrnn_action.shape == (seq_len, gym_env.action_dim)
            assert rewards.shape == (seq_len, )
            assert next_states.shape == (seq_len, gym_env.state_dim)
            assert not_terminals.shape == (seq_len, )

            replay_buffer.insert_into_memory(mdnrnn_state, mdnrnn_action,
                                             next_states, rewards,
                                             not_terminals)

    return replay_buffer
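
Example #4 calls a dict_to_np helper that is not shown here. Based purely on how it is called (an np_size and key_offset argument, dense float32 state/action vectors as output), a plausible sketch follows; the real helper in the source may differ.

import numpy as np

def dict_to_np(d, np_size, key_offset):
    # Assumed behavior: densify a sparse {feature_id: value} dict into a
    # float32 vector of length np_size, with feature ids shifted by key_offset.
    out = np.zeros(np_size, dtype=np.float32)
    for key, value in d.items():
        out[int(key) - key_offset] = value
    return out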