Exemplo n.º 1
0
    def test_mdrnn_learning(self):
        """Check that an MDRNN can fit rollouts of a SimulatedWorldModel.

        The test runs in three phases:

        1. Roll out ``num_episodes`` episodes of ``seq_len`` steps from a
           freshly constructed ``SimulatedWorldModel``.  Each step applies a
           uniformly chosen one-hot action, and the next state is one of the
           mixture means returned by the simulator, picked with uniform
           mixture weights.  Every transition is stored in numpy buffers.
        2. Train an ``MDRNN`` with Adam on mini-batches of episodes sampled
           with replacement from those buffers, for at most ``num_epochs``
           epochs, stopping early once the epoch loss is below the threshold
           and the early-stopping helper signals a stop.
        3. Assert that the mean loss over the last epoch's steps is below
           ``-3`` and print the GMM loss of the trained model on the first
           recorded episode.

        Raises:
            AssertionError: if the mean loss of the final epoch is not
                below the threshold.
        """
        num_epochs = 300
        num_episodes = 400
        batch_size = 200
        action_dim = 2
        seq_len = 5
        state_dim = 2
        simulated_num_gaussian = 2
        mdrnn_num_gaussian = 2
        simulated_hidden_size = 3
        mdrnn_hidden_size = 10
        mdrnn_hidden_layer = 1
        adam_lr = 0.01
        # Mean per-epoch loss the trained model has to beat.
        loss_threshold = -3.

        # Replay buffers, one row per episode and one column per step.
        cur_state_mem = numpy.zeros((num_episodes, seq_len, state_dim))
        next_state_mem = numpy.zeros((num_episodes, seq_len, state_dim))
        action_mem = numpy.zeros((num_episodes, seq_len, action_dim))
        reward_mem = numpy.zeros((num_episodes, seq_len))
        terminal_mem = numpy.zeros((num_episodes, seq_len))
        next_mus_mem = numpy.zeros(
            (num_episodes, seq_len, simulated_num_gaussian, state_dim))

        def to_float_tensors(indices):
            # Gather the selected episodes from every buffer as float tensors,
            # in (obs, action, reward, terminal, next_obs) order.
            return tuple(
                torch.tensor(mem[indices], dtype=torch.float)
                for mem in (cur_state_mem, action_mem, reward_mem,
                            terminal_mem, next_state_mem))

        swm = SimulatedWorldModel(
            action_dim=action_dim,
            state_dim=state_dim,
            num_gaussian=simulated_num_gaussian,
            lstm_num_layer=1,
            lstm_hidden_dim=simulated_hidden_size,
        )

        # Rows of the identity matrix are the one-hot action vectors.
        actions = torch.eye(action_dim)
        # Uniform mixture weights are the same at every step, so build the
        # weights and the sampling distribution once.
        next_pi = torch.ones(simulated_num_gaussian) / simulated_num_gaussian
        component_dist = Categorical(next_pi)
        for e in range(num_episodes):
            swm.init_hidden(batch_size=1)
            next_state = torch.randn((1, 1, state_dim))
            for s in range(seq_len):
                cur_state = next_state

                # clone().detach() copies the selected row; calling
                # torch.tensor() on an existing tensor is discouraged and
                # emits a UserWarning.
                one_hot = actions[numpy.random.randint(action_dim)]
                action = one_hot.clone().detach().view(1, 1, action_dim)
                next_mus, reward = swm(action, cur_state)
                # Only the last step of an episode is terminal.
                terminal = int(s == seq_len - 1)

                index = component_dist.sample((1, )).long().item()
                next_state = next_mus[0, 0, index].view(1, 1, state_dim)

                print(
                    "{} cur_state: {}, action: {}, next_state: {}, reward: {}, terminal: {}"
                    .format(e, cur_state, action, next_state, reward,
                            terminal))
                print("next_pi: {}, sampled index: {}".format(next_pi, index))
                print("next_mus:", next_mus, "\n")

                cur_state_mem[e, s, :] = cur_state.detach().numpy()
                action_mem[e, s, :] = action.numpy()
                reward_mem[e, s] = reward.detach().numpy()
                terminal_mem[e, s] = terminal
                next_state_mem[e, s, :] = next_state.detach().numpy()
                next_mus_mem[e, s, :, :] = next_mus.detach().numpy()

        mdrnn = MDRNN(
            latents=state_dim,
            actions=action_dim,
            gaussians=mdrnn_num_gaussian,
            hiddens=mdrnn_hidden_size,
            layers=mdrnn_hidden_layer,
        )
        mdrnn.train()
        optimizer = torch.optim.Adam(mdrnn.parameters(), lr=adam_lr)
        num_batch = num_episodes // batch_size
        earlystopping = EarlyStopping('min', patience=30)

        # Per-step loss histories across all epochs.
        cum_loss = []
        cum_gmm = []
        cum_bce = []
        cum_mse = []
        for e in range(num_epochs):
            for i in range(0, num_batch):
                mdrnn.init_hidden(batch_size=batch_size)
                optimizer.zero_grad()
                # Episodes are drawn with replacement.
                sample_indices = numpy.random.randint(num_episodes,
                                                      size=batch_size)
                obs, action, reward, terminal, next_obs = \
                    to_float_tensors(sample_indices)

                print("learning at epoch {} step {} best score {} counter {}".
                      format(e, i, earlystopping.best,
                             earlystopping.num_bad_epochs))
                losses = self.get_loss(obs, action, reward, terminal, next_obs,
                                       state_dim, mdrnn)
                losses['loss'].backward()
                optimizer.step()

                # Pull each scalar out once and reuse it for both the
                # history lists and the log line.
                step_loss = losses['loss'].item()
                step_gmm = losses['gmm'].item()
                step_bce = losses['bce'].item()
                step_mse = losses['mse'].item()
                cum_loss.append(step_loss)
                cum_gmm.append(step_gmm)
                cum_bce.append(step_bce)
                cum_mse.append(step_mse)
                print(
                    "loss={loss:10.6f} bce={bce:10.6f} gmm={gmm:10.6f} mse={mse:10.6f}"
                    .format(
                        loss=step_loss,
                        bce=step_bce,
                        gmm=step_gmm,
                        mse=step_mse,
                    ))
                print(
                    "cum loss={loss:10.6f} cum bce={bce:10.6f} cum gmm={gmm:10.6f} cum mse={mse:10.6f}"
                    .format(
                        loss=numpy.mean(cum_loss),
                        bce=numpy.mean(cum_bce),
                        gmm=numpy.mean(cum_gmm),
                        mse=numpy.mean(cum_mse),
                    ))

                print()

            # Mean loss over the steps of the epoch that just finished.
            epoch_loss = numpy.mean(cum_loss[-num_batch:])
            earlystopping.step(epoch_loss)
            # Stop only when the target is reached AND early stopping fired.
            if epoch_loss < loss_threshold and earlystopping.stop:
                break

        assert numpy.mean(cum_loss[-num_batch:]) < loss_threshold

        # Evaluate the trained model on the first recorded episode.
        sample_indices = [0]
        mdrnn.init_hidden(batch_size=len(sample_indices))
        mdrnn.eval()
        obs, action, reward, terminal, next_obs = \
            to_float_tensors(sample_indices)
        # Only the observation, action and next-observation tensors are
        # needed after the transpose.
        transpose_obs, transpose_action, _, _, transpose_next_obs = \
            self.transpose(obs, action, reward, terminal, next_obs)
        # The last two outputs of the model are not needed for the GMM loss.
        mus, sigmas, logpi, _, _ = mdrnn(transpose_action, transpose_obs)
        gl = gmm_loss(transpose_next_obs, mus, sigmas, logpi)
        print(gl)

        print()
Exemplo n.º 2
0
# Resume from a saved checkpoint when one exists, unless reloading was
# disabled through args.noreload.
if exists(rnn_file) and not args.noreload:
    # The checkpoint is indexed below by the keys "epoch", "precision",
    # "state_dict" and "optimizer".
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    rnn_state = torch.load(rnn_file)
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(rnn_state["epoch"],
                                      rnn_state["precision"]))
    # Restore both the model weights and the optimizer state.
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    # NOTE(review): scheduler / early-stopping state is not restored; the
    # disabled lines below also refer to `state`, not `rnn_state`.
    # scheduler.load_state_dict(state['scheduler'])
    # earlystopping.load_state_dict(state['earlystopping'])

# Loader over the dataset built with train=False (presumably the held-out
# split -- confirm against _RolloutDataset), read by 8 worker processes.
test_loader = DataLoader(_RolloutDataset('datasets/carracing2_rnn',
                                         train=False),
                         batch_size=BSIZE,
                         num_workers=8)

# Put the model in evaluation mode before iterating over the data.
mdrnn.eval()
loader = test_loader

# Accumulators for per-batch results; presumably filled by the loop that
# follows -- TODO confirm, the appends are not visible in this section.
mu_record = []
sigma_record = []
pi_record = []
output_record = []
for i, data in enumerate(loader):

    data = data.cuda()
    input = data[:, :-1, :]
    output = data[:, 1:, :32]
    with torch.no_grad():
        mu, sigma, pi = mdrnn(input, train=False)
        sigma = torch.exp(sigma).transpose(1, 0).view(8, 999, 32, 5)
        pi = torch.softmax(pi, dim=-1).transpose(1, 0).view(8, 999, 32, 5)