# Assumed context: both snippets below come from ReAgent's world-model (MDN-RNN)
# training code. The ReAgent names they reference (EnvWrapper, MDNRNNTrainer,
# ReplayBuffer, Agent, make_random_policy_for_env, fill_replay_buffer,
# MDNRNNMemoryPool, SimulatedWorldModel, MDNRNNTrainerParameters, MemoryNetwork,
# WorldModelReporter) are imported at module level; exact paths vary by version.
import logging

import numpy as np
import torch
from torch.distributions import Categorical

logger = logging.getLogger(__name__)


def train_mdnrnn(
    env: EnvWrapper,
    trainer: MDNRNNTrainer,
    trainer_preprocessor,
    num_train_transitions: int,
    seq_len: int,
    batch_size: int,
    num_train_epochs: int,
    # for optional validation
    test_replay_buffer=None,
):
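    # Replay buffer sized to hold the entire training set. stack_size=seq_len together
    # with return_everything_as_stack=True makes each sampled transition a length-seq_len
    # sequence, which is the input shape the MDN-RNN expects.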
    train_replay_buffer = ReplayBuffer(
        replay_capacity=num_train_transitions,
        batch_size=batch_size,
        stack_size=seq_len,
        return_everything_as_stack=True,
    )
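    # Fill the buffer by rolling out a uniformly random policy in the environment.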
    random_policy = make_random_policy_for_env(env)
    agent = Agent.create_for_env(env, policy=random_policy)
    fill_replay_buffer(env, train_replay_buffer, num_train_transitions, agent)
    num_batch_per_epoch = train_replay_buffer.size // batch_size

    logger.info("Made RBs, starting to train now!")
    optimizer = trainer.configure_optimizers()[0]
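    # Manual optimization loop: train_step_gen yields the loss for one minibatch,
    # which is then backpropagated and applied explicitly.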
    for _ in range(num_train_epochs):
        for i in range(num_batch_per_epoch):
            batch = train_replay_buffer.sample_transition_batch(batch_size=batch_size)
            preprocessed_batch = trainer_preprocessor(batch)
            loss = next(trainer.train_step_gen(preprocessed_batch, i))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # validation
        if test_replay_buffer is not None:
            with torch.no_grad():
                trainer.memory_network.mdnrnn.eval()
                test_batch = test_replay_buffer.sample_transition_batch(
                    batch_size=batch_size
                )
                preprocessed_test_batch = trainer_preprocessor(test_batch)
                valid_losses = trainer.get_loss(preprocessed_test_batch)
                trainer.memory_network.mdnrnn.train()
    return trainer
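
A minimal sketch of how train_mdnrnn might be invoked. The construction of env, trainer,
and trainer_preprocessor is elided, and the numeric arguments are illustrative assumptions
rather than values prescribed by ReAgent:

# Sketch only: `env`, `trainer`, and `trainer_preprocessor` are assumed to have been
# built elsewhere with the ReAgent helpers appropriate to your version.
trained_trainer = train_mdnrnn(
    env=env,
    trainer=trainer,
    trainer_preprocessor=trainer_preprocessor,
    num_train_transitions=100_000,  # illustrative value
    seq_len=5,
    batch_size=200,
    num_train_epochs=10,
    test_replay_buffer=None,  # pass a filled ReplayBuffer to enable per-epoch validation
)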
Example #2
    def _test_mdnrnn_simulate_world(self, use_gpu=False):
        num_epochs = 300
        num_episodes = 400
        batch_size = 200
        action_dim = 2
        seq_len = 5
        state_dim = 2
        simulated_num_gaussians = 2
        mdnrnn_num_gaussians = 2
        simulated_num_hidden_layers = 1
        simulated_num_hiddens = 3
        mdnrnn_num_hidden_layers = 1
        mdnrnn_num_hiddens = 10
        adam_lr = 0.01

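        # Memory pool for generated episodes, plus a ground-truth "simulated world"
        # (a small LSTM with a mixture-of-Gaussians head) whose rollouts the MDN-RNN
        # under test must learn to reproduce.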
        replay_buffer = MDNRNNMemoryPool(max_replay_memory_size=num_episodes)
        swm = SimulatedWorldModel(
            action_dim=action_dim,
            state_dim=state_dim,
            num_gaussians=simulated_num_gaussians,
            lstm_num_hidden_layers=simulated_num_hidden_layers,
            lstm_num_hiddens=simulated_num_hiddens,
        )

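        # Roll out num_episodes sequences of length seq_len from the simulated world,
        # picking one-hot actions uniformly at random, and store them in the memory pool.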
        possible_actions = torch.eye(action_dim)
        for _ in range(num_episodes):
            cur_state_mem = torch.zeros((seq_len, state_dim))
            next_state_mem = torch.zeros((seq_len, state_dim))
            action_mem = torch.zeros((seq_len, action_dim))
            reward_mem = torch.zeros(seq_len)
            not_terminal_mem = torch.zeros(seq_len)
            next_mus_mem = torch.zeros((seq_len, simulated_num_gaussians, state_dim))

            swm.init_hidden(batch_size=1)
            next_state = torch.randn((1, 1, state_dim))
            for s in range(seq_len):
                cur_state = next_state
                action = possible_actions[np.random.randint(action_dim)].view(
                    1, 1, action_dim
                )
                next_mus, reward = swm(action, cur_state)

                not_terminal = 1
                if s == seq_len - 1:
                    not_terminal = 0

                # randomly draw for next state
                next_pi = torch.ones(simulated_num_gaussians) / simulated_num_gaussians
                index = Categorical(next_pi).sample((1,)).long().item()
                next_state = next_mus[0, 0, index].view(1, 1, state_dim)

                cur_state_mem[s] = cur_state.detach()
                action_mem[s] = action
                reward_mem[s] = reward.detach()
                not_terminal_mem[s] = not_terminal
                next_state_mem[s] = next_state.detach()
                next_mus_mem[s] = next_mus.detach()

            replay_buffer.insert_into_memory(
                cur_state_mem, action_mem, next_state_mem, reward_mem, not_terminal_mem
            )

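        # MDN-RNN under test: configured and trained to fit the simulated world's dynamics.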
        num_batch = num_episodes // batch_size
        mdnrnn_params = MDNRNNTrainerParameters(
            hidden_size=mdnrnn_num_hiddens,
            num_hidden_layers=mdnrnn_num_hidden_layers,
            minibatch_size=batch_size,
            learning_rate=adam_lr,
            num_gaussians=mdnrnn_num_gaussians,
        )
        mdnrnn_net = MemoryNetwork(
            state_dim=state_dim,
            action_dim=action_dim,
            num_hiddens=mdnrnn_params.hidden_size,
            num_hidden_layers=mdnrnn_params.num_hidden_layers,
            num_gaussians=mdnrnn_params.num_gaussians,
        )
        if use_gpu:
            mdnrnn_net = mdnrnn_net.cuda()
        trainer = MDNRNNTrainer(
            memory_network=mdnrnn_net, params=mdnrnn_params, cum_loss_hist=num_batch
        )
        reporter = WorldModelReporter(report_interval=1)
        trainer.set_reporter(reporter)

        optimizer = trainer.configure_optimizers()[0]
        for e in range(num_epochs):
            for i in range(num_batch):
                training_batch = replay_buffer.sample_memories(
                    batch_size, use_gpu=use_gpu
                )
                optimizer.zero_grad()
                loss = next(trainer.train_step_gen(training_batch, i))
                loss.backward()
                optimizer.step()

                logger.info(
                    "{}-th epoch, {}-th minibatch: \n"
                    "loss={}, bce={}, gmm={}, mse={} \n"
                    "cum loss={}, cum bce={}, cum gmm={}, cum mse={}\n".format(
                        e,
                        i,
                        reporter.loss.values[-1],
                        reporter.bce.values[-1],
                        reporter.gmm.values[-1],
                        reporter.mse.values[-1],
                        np.mean(reporter.loss.values[-100:]),
                        np.mean(reporter.bce.values[-100:]),
                        np.mean(reporter.gmm.values[-100:]),
                        np.mean(reporter.mse.values[-100:]),
                    )
                )

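                # Stop early once the running means over the last 100 minibatches of all
                # loss components fall below loose thresholds, i.e. the MDN-RNN has fit
                # the simulated world reasonably well.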
                if (
                    np.mean(reporter.loss.values[-100:]) < 0
                    and np.mean(reporter.gmm.values[-100:]) < -3.0
                    and np.mean(reporter.bce.values[-100:]) < 0.6
                    and np.mean(reporter.mse.values[-100:]) < 0.2
                ):
                    return

        raise RuntimeError("losses not reduced significantly during training")