Example #1
def create_trainer(
    params: OpenAiGymParameters, env: OpenAIGymEnvironment, use_gpu: bool
):
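    """Build an MDN-RNN MemoryNetwork for the env and wrap it in an MDNRNNTrainer."""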
    assert params.mdnrnn is not None
    assert params.run_details.max_steps is not None
    mdnrnn_params = params.mdnrnn
    mdnrnn_net = MemoryNetwork(
        state_dim=env.state_dim,
        action_dim=env.action_dim,
        num_hiddens=mdnrnn_params.hidden_size,
        num_hidden_layers=mdnrnn_params.num_hidden_layers,
        num_gaussians=mdnrnn_params.num_gaussians,
    )
    if use_gpu:
        mdnrnn_net = mdnrnn_net.cuda()

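    # History window for cumulative losses: roughly one entry per training
    # minibatch across all training episodes.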
    cum_loss_hist_len = (
        params.run_details.num_train_episodes
        * params.run_details.max_steps
        // mdnrnn_params.minibatch_size
    )
    trainer = MDNRNNTrainer(
        mdnrnn_network=mdnrnn_net, params=mdnrnn_params, cum_loss_hist=cum_loss_hist_len
    )
    return trainer
Example #2
def create_world_model_trainer(
    env: OpenAIGymEnvironment, mdnrnn_params: MDNRNNParameters, use_gpu: bool
) -> MDNRNNTrainer:
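    """Create the MDN-RNN world-model network and its trainer for the given env."""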
    mdnrnn_net = MemoryNetwork(
        state_dim=env.state_dim,
        action_dim=env.action_dim,
        num_hiddens=mdnrnn_params.hidden_size,
        num_hidden_layers=mdnrnn_params.num_hidden_layers,
        num_gaussians=mdnrnn_params.num_gaussians,
    )
    if use_gpu:
        mdnrnn_net = mdnrnn_net.cuda()
    mdnrnn_trainer = MDNRNNTrainer(mdnrnn_network=mdnrnn_net, params=mdnrnn_params)
    return mdnrnn_trainer
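
# Minimal usage sketch: hypothetical hyperparameter values; assumes `env` is an
# OpenAIGymEnvironment instance constructed elsewhere.
mdnrnn_params = MDNRNNParameters(
    hidden_size=64,
    num_hidden_layers=2,
    minibatch_size=16,
    learning_rate=0.001,
    num_gaussians=5,
)
world_model_trainer = create_world_model_trainer(env, mdnrnn_params, use_gpu=False)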
Example #3
def create_trainer(params, env):
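    """Dict-config variant: build the MDN-RNN network and trainer (CPU only)."""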
    mdnrnn_params = MDNRNNParameters(**params["mdnrnn"])
    mdnrnn_net = MemoryNetwork(
        state_dim=env.state_dim,
        action_dim=env.action_dim,
        num_hiddens=mdnrnn_params.hidden_size,
        num_hidden_layers=mdnrnn_params.num_hidden_layers,
        num_gaussians=mdnrnn_params.num_gaussians,
    )
    cum_loss_hist_len = (params["run_details"]["num_train_episodes"] *
                         params["run_details"]["max_steps"] //
                         mdnrnn_params.minibatch_size)
    trainer = MDNRNNTrainer(mdnrnn_network=mdnrnn_net,
                            params=mdnrnn_params,
                            cum_loss_hist=cum_loss_hist_len)
    return trainer
Example #4
def create_trainer(params: Dict, env: OpenAIGymEnvironment, use_gpu: bool):
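    """Dict-config variant that also moves the network to GPU when one is available."""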
    mdnrnn_params = MDNRNNParameters(**params["mdnrnn"])
    mdnrnn_net = MemoryNetwork(
        state_dim=env.state_dim,
        action_dim=env.action_dim,
        num_hiddens=mdnrnn_params.hidden_size,
        num_hidden_layers=mdnrnn_params.num_hidden_layers,
        num_gaussians=mdnrnn_params.num_gaussians,
    )
    if use_gpu and torch.cuda.is_available():
        mdnrnn_net = mdnrnn_net.cuda()

    cum_loss_hist_len = (params["run_details"]["num_train_episodes"] *
                         params["run_details"]["max_steps"] //
                         mdnrnn_params.minibatch_size)
    trainer = MDNRNNTrainer(mdnrnn_network=mdnrnn_net,
                            params=mdnrnn_params,
                            cum_loss_hist=cum_loss_hist_len)
    return trainer
Example #5
def train_sgd(
    gym_env: OpenAIGymEnvironment,
    trainer: MDNRNNTrainer,
    use_gpu: bool,
    test_run_name: str,
    minibatch_size: int,
    run_details: OpenAiRunDetails,
):
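    """Train the MDN-RNN world model on replay data collected from the Gym env.

    Runs epoch-based minibatch training with early stopping on validation loss,
    then evaluates once on a held-out test buffer. Returns the test losses, the
    validation loss history, and the trainer.
    """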
    assert run_details.max_steps is not None
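    # Collect separate replay buffers for training, validation (used for early
    # stopping), and final test evaluation.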
    train_replay_buffer = get_replay_buffer(
        run_details.num_train_episodes,
        run_details.seq_len,
        run_details.max_steps,
        gym_env,
    )
    valid_replay_buffer = get_replay_buffer(
        run_details.num_test_episodes,
        run_details.seq_len,
        run_details.max_steps,
        gym_env,
    )
    test_replay_buffer = get_replay_buffer(
        run_details.num_test_episodes,
        run_details.seq_len,
        run_details.max_steps,
        gym_env,
    )
    valid_loss_history = []

    num_batch_per_epoch = train_replay_buffer.memory_size // minibatch_size
    logger.info(
        "Collected data {} transitions.\n"
        "Training will take {} epochs, with each epoch having {} mini-batches"
        " and each mini-batch having {} samples".format(
            train_replay_buffer.memory_size,
            run_details.train_epochs,
            num_batch_per_epoch,
            minibatch_size,
        ))

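    # Epoch / minibatch training loop; validation loss is computed after each epoch.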
    for i_epoch in range(run_details.train_epochs):
        for i_batch in range(num_batch_per_epoch):
            training_batch = train_replay_buffer.sample_memories(
                minibatch_size, use_gpu=use_gpu, batch_first=True)
            losses = trainer.train(training_batch, batch_first=True)
            logger.info(
                "{}-th epoch, {}-th minibatch: \n"
                "loss={}, bce={}, gmm={}, mse={} \n"
                "cum loss={}, cum bce={}, cum gmm={}, cum mse={}\n".format(
                    i_epoch,
                    i_batch,
                    losses["loss"],
                    losses["bce"],
                    losses["gmm"],
                    losses["mse"],
                    np.mean(trainer.cum_loss),
                    np.mean(trainer.cum_bce),
                    np.mean(trainer.cum_gmm),
                    np.mean(trainer.cum_mse),
                ))

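        # End of epoch: evaluate on the validation buffer in eval mode.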
        trainer.mdnrnn.mdnrnn.eval()
        valid_batch = valid_replay_buffer.sample_memories(
            valid_replay_buffer.memory_size, use_gpu=use_gpu, batch_first=True)
        valid_losses = trainer.get_loss(valid_batch,
                                        state_dim=gym_env.state_dim,
                                        batch_first=True)
        valid_losses = loss_to_num(valid_losses)
        valid_loss_history.append(valid_losses)
        trainer.mdnrnn.mdnrnn.train()
        logger.info(
            "{}-th epoch, validate loss={}, bce={}, gmm={}, mse={}".format(
                i_epoch,
                valid_losses["loss"],
                valid_losses["bce"],
                valid_losses["gmm"],
                valid_losses["mse"],
            ))
        latest_loss = valid_loss_history[-1]["loss"]
        recent_valid_loss_hist = valid_loss_history[
            -1 - run_details.early_stopping_patience:-1]
        # Early stopping: quit once the latest validation loss is no better than
        # every loss in the recent history window.
        if len(valid_loss_history) > run_details.early_stopping_patience and all(
                latest_loss >= v["loss"] for v in recent_valid_loss_hist):
            break

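    # Final evaluation on the held-out test buffer.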
    trainer.mdnrnn.mdnrnn.eval()
    test_batch = test_replay_buffer.sample_memories(
        test_replay_buffer.memory_size, use_gpu=use_gpu, batch_first=True)
    test_losses = trainer.get_loss(test_batch,
                                   state_dim=gym_env.state_dim,
                                   batch_first=True)
    test_losses = loss_to_num(test_losses)
    logger.info("Test loss: {}, bce={}, gmm={}, mse={}".format(
        test_losses["loss"],
        test_losses["bce"],
        test_losses["gmm"],
        test_losses["mse"],
    ))
    logger.info("Valid loss history: {}".format(valid_loss_history))
    return test_losses, valid_loss_history, trainer
Example #6
def create_embed_rl_dataset(
    gym_env: OpenAIGymEnvironment,
    trainer: MDNRNNTrainer,
    dataset: RLDataset,
    use_gpu: bool,
    run_details: OpenAiRunDetails,
):
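    """Embed states with the trained MDN-RNN and fill `dataset` with transitions.

    Each state embedding is the MDN-RNN's LSTM hidden output over the preceding
    steps concatenated with the raw current state.
    """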
    assert run_details.max_steps is not None
    old_mdnrnn_mode = trainer.mdnrnn.mdnrnn.training
    trainer.mdnrnn.mdnrnn.eval()
    num_transitions = run_details.num_state_embed_episodes * run_details.max_steps
    device = torch.device("cuda") if use_gpu else torch.device("cpu")  # type: ignore

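    # Sample multi-step transitions from the env and unpack them into per-field lists.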
    (
        state_batch,
        action_batch,
        reward_batch,
        next_state_batch,
        next_action_batch,
        not_terminal_batch,
        step_batch,
        next_step_batch,
    ) = map(
        list,
        zip(*multi_step_sample_generator(
            gym_env=gym_env,
            num_transitions=num_transitions,
            max_steps=run_details.max_steps,
            # +1 because MDNRNN embeds the first seq_len steps and then
            # the embedded state will be concatenated with the last step
            multi_steps=run_details.seq_len + 1,
            include_shorter_samples_at_start=True,
            include_shorter_samples_at_end=False,
        )),
    )

    def concat_batch(batch):
        return torch.cat(
            [
                torch.tensor(np.expand_dims(x, axis=1),
                             dtype=torch.float,
                             device=device) for x in batch
            ],
            dim=1,
        )

    # shape: seq_len x batch_size x feature_dim
    mdnrnn_state = concat_batch(state_batch)
    next_mdnrnn_state = concat_batch(next_state_batch)
    mdnrnn_action = concat_batch(action_batch)
    next_mdnrnn_action = concat_batch(next_action_batch)

    mdnrnn_input = rlt.PreprocessedStateAction.from_tensors(
        state=mdnrnn_state, action=mdnrnn_action)
    next_mdnrnn_input = rlt.PreprocessedStateAction.from_tensors(
        state=next_mdnrnn_state, action=next_mdnrnn_action)
    # batch-compute state embedding
    mdnrnn_output = trainer.mdnrnn(mdnrnn_input)
    next_mdnrnn_output = trainer.mdnrnn(next_mdnrnn_input)

    for i in range(len(state_batch)):
        # Embed the state as the hidden layer's output
        # until the previous step + current state
        hidden_idx = 0 if step_batch[i] == 1 else step_batch[i] - 2  # type: ignore
        next_hidden_idx = next_step_batch[i] - 2  # type: ignore
        hidden_embed = (
            mdnrnn_output.all_steps_lstm_hidden[hidden_idx, i, :]
            .squeeze().detach().cpu())
        state_embed = torch.cat(
            (hidden_embed, torch.tensor(state_batch[i][hidden_idx + 1]))  # type: ignore
        )
        next_hidden_embed = (
            next_mdnrnn_output.all_steps_lstm_hidden[next_hidden_idx, i, :]
            .squeeze().detach().cpu())
        next_state_embed = torch.cat((
            next_hidden_embed,
            torch.tensor(next_state_batch[i][next_hidden_idx + 1]),  # type: ignore
        ))

        logger.debug(
            "create_embed_rl_dataset:\nstate batch\n{}\naction batch\n{}\nlast "
            "action: {}, reward: {}\nstate embed {}\nnext state embed {}\n".format(
                state_batch[i][:hidden_idx + 1],  # type: ignore
                action_batch[i][:hidden_idx + 1],  # type: ignore
                action_batch[i][hidden_idx + 1],  # type: ignore
                reward_batch[i][hidden_idx + 1],  # type: ignore
                state_embed,
                next_state_embed,
            ))

        terminal = 1 - not_terminal_batch[i][hidden_idx + 1]  # type: ignore
        possible_actions, possible_actions_mask = get_possible_actions(
            gym_env, ModelType.PYTORCH_PARAMETRIC_DQN.value, False)
        possible_next_actions, possible_next_actions_mask = get_possible_actions(
            gym_env, ModelType.PYTORCH_PARAMETRIC_DQN.value, terminal)
        dataset.insert(
            state=state_embed,
            action=torch.tensor(action_batch[i][hidden_idx + 1]),  # type: ignore
            reward=reward_batch[i][hidden_idx + 1],  # type: ignore
            next_state=next_state_embed,
            next_action=torch.tensor(
                next_action_batch[i][next_hidden_idx + 1]  # type: ignore
            ),
            terminal=torch.tensor(terminal),
            possible_next_actions=possible_next_actions,
            possible_next_actions_mask=possible_next_actions_mask,
            time_diff=torch.tensor(1),
            possible_actions=possible_actions,
            possible_actions_mask=possible_actions_mask,
            policy_id=0,
        )
    logger.info("Insert {} transitions into a state embed dataset".format(
        len(state_batch)))
    trainer.mdnrnn.mdnrnn.train(old_mdnrnn_mode)
    return dataset
Example #7
    def test_mdnrnn_simulate_world(self):
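        """Train an MDN-RNN on data from a simulated world model and check that
        the cumulative training losses drop below fixed thresholds.
        """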
        num_epochs = 300
        num_episodes = 400
        batch_size = 200
        action_dim = 2
        seq_len = 5
        state_dim = 2
        simulated_num_gaussians = 2
        mdrnn_num_gaussians = 2
        simulated_num_hidden_layers = 1
        simulated_num_hiddens = 3
        mdnrnn_num_hidden_layers = 1
        mdnrnn_num_hiddens = 10
        adam_lr = 0.01

        replay_buffer = MDNRNNMemoryPool(max_replay_memory_size=num_episodes)
        swm = SimulatedWorldModel(
            action_dim=action_dim,
            state_dim=state_dim,
            num_gaussians=simulated_num_gaussians,
            lstm_num_hidden_layers=simulated_num_hidden_layers,
            lstm_num_hiddens=simulated_num_hiddens,
        )

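        # Roll out the simulated world model to fill the replay buffer with episodes.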
        possible_actions = torch.eye(action_dim)
        for _ in range(num_episodes):
            cur_state_mem = np.zeros((seq_len, state_dim))
            next_state_mem = np.zeros((seq_len, state_dim))
            action_mem = np.zeros((seq_len, action_dim))
            reward_mem = np.zeros(seq_len)
            not_terminal_mem = np.zeros(seq_len)
            next_mus_mem = np.zeros(
                (seq_len, simulated_num_gaussians, state_dim))

            swm.init_hidden(batch_size=1)
            next_state = torch.randn((1, 1, state_dim))
            for s in range(seq_len):
                cur_state = next_state
                action = possible_actions[np.random.randint(action_dim)].view(
                    1, 1, action_dim)
                next_mus, reward = swm(action, cur_state)

                not_terminal = 1
                if s == seq_len - 1:
                    not_terminal = 0

                # randomly draw for next state
                next_pi = torch.ones(
                    simulated_num_gaussians) / simulated_num_gaussians
                index = Categorical(next_pi).sample((1, )).long().item()
                next_state = next_mus[0, 0, index].view(1, 1, state_dim)

                cur_state_mem[s] = cur_state.detach().numpy()
                action_mem[s] = action.numpy()
                reward_mem[s] = reward.detach().numpy()
                not_terminal_mem[s] = not_terminal
                next_state_mem[s] = next_state.detach().numpy()
                next_mus_mem[s] = next_mus.detach().numpy()

            replay_buffer.insert_into_memory(cur_state_mem, action_mem,
                                             next_state_mem, reward_mem,
                                             not_terminal_mem)

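        # Fit an MDN-RNN trainer to the simulated episodes.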
        num_batch = num_episodes // batch_size
        mdnrnn_params = MDNRNNParameters(
            hidden_size=mdnrnn_num_hiddens,
            num_hidden_layers=mdnrnn_num_hidden_layers,
            minibatch_size=batch_size,
            learning_rate=adam_lr,
            num_gaussians=mdrnn_num_gaussians,
        )
        mdnrnn_net = MemoryNetwork(
            state_dim=state_dim,
            action_dim=action_dim,
            num_hiddens=mdnrnn_params.hidden_size,
            num_hidden_layers=mdnrnn_params.num_hidden_layers,
            num_gaussians=mdnrnn_params.num_gaussians,
        )
        trainer = MDNRNNTrainer(mdnrnn_network=mdnrnn_net,
                                params=mdnrnn_params,
                                cum_loss_hist=num_batch)

        for e in range(num_epochs):
            for i in range(num_batch):
                training_batch = replay_buffer.sample_memories(batch_size)
                losses = trainer.train(training_batch)
                logger.info(
                    "{}-th epoch, {}-th minibatch: \n"
                    "loss={}, bce={}, gmm={}, mse={} \n"
                    "cum loss={}, cum bce={}, cum gmm={}, cum mse={}\n".format(
                        e,
                        i,
                        losses["loss"],
                        losses["bce"],
                        losses["gmm"],
                        losses["mse"],
                        np.mean(trainer.cum_loss),
                        np.mean(trainer.cum_bce),
                        np.mean(trainer.cum_gmm),
                        np.mean(trainer.cum_mse),
                    ))

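                # Stop the test early once all running mean losses fall below
                # their thresholds.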
                if (np.mean(trainer.cum_loss) < 0
                        and np.mean(trainer.cum_gmm) < -3.0
                        and np.mean(trainer.cum_bce) < 0.6
                        and np.mean(trainer.cum_mse) < 0.2):
                    return

        assert False, "losses not reduced significantly during training"