def test_mdrnn_learning(self):
    """End-to-end learning test for the MDRNN world model.

    Pipeline:
      1. Roll out `num_episodes` episodes of length `seq_len` from a
         `SimulatedWorldModel` (a ground-truth mixture-density LSTM) and
         store transitions in numpy buffers.
      2. Train an `MDRNN` on random mini-batches of those episodes with
         Adam, tracking loss components and an early-stopping monitor.
      3. Assert the trailing mean training loss drops below -3, then run
         one eval forward pass and print its GMM loss.

    NOTE(review): source formatting was collapsed; loop nesting below is
    reconstructed from syntax — confirm against the original file.
    """
    # --- hyperparameters for the simulator and the learned model ---
    num_epochs = 300
    num_episodes = 400
    batch_size = 200
    action_dim = 2
    seq_len = 5
    state_dim = 2
    simulated_num_gaussian = 2   # mixture components in the ground-truth model
    mdrnn_num_gaussian = 2       # mixture components in the learned MDRNN
    simulated_hidden_size = 3
    mdrnn_hidden_size = 10
    mdrnn_hidden_layer = 1
    adam_lr = 0.01

    # Replay buffers, indexed [episode, step, ...].
    cur_state_mem = numpy.zeros((num_episodes, seq_len, state_dim))
    next_state_mem = numpy.zeros((num_episodes, seq_len, state_dim))
    action_mem = numpy.zeros((num_episodes, seq_len, action_dim))
    reward_mem = numpy.zeros((num_episodes, seq_len))
    terminal_mem = numpy.zeros((num_episodes, seq_len))
    # Ground-truth mixture means per step, kept for debugging/inspection.
    next_mus_mem = numpy.zeros(
        (num_episodes, seq_len, simulated_num_gaussian, state_dim))

    # Ground-truth world model that generates the training data.
    swm = SimulatedWorldModel(
        action_dim=action_dim,
        state_dim=state_dim,
        num_gaussian=simulated_num_gaussian,
        lstm_num_layer=1,
        lstm_hidden_dim=simulated_hidden_size,
    )

    # One-hot action vocabulary: row i is action i.
    actions = torch.eye(action_dim)

    # --- phase 1: generate episodes from the simulator ---
    for e in range(num_episodes):
        swm.init_hidden(batch_size=1)
        next_state = torch.randn((1, 1, state_dim))  # random initial state
        for s in range(seq_len):
            cur_state = next_state
            # Pick a uniformly random one-hot action, shaped (1, 1, action_dim).
            # NOTE(review): torch.tensor() on an existing tensor copies and
            # warns on recent torch versions — .clone().detach() is preferred.
            action = torch.tensor(
                actions[numpy.random.randint(action_dim)]).view(
                    1, 1, action_dim)
            next_mus, reward = swm(action, cur_state)
            terminal = 0
            if s == seq_len - 1:
                terminal = 1  # episodes terminate only at the last step

            # Sample the next state from the mixture: uniform component
            # choice, then take that component's mean (no noise added).
            next_pi = torch.ones(
                simulated_num_gaussian) / simulated_num_gaussian
            index = Categorical(next_pi).sample((1, )).long().item()
            next_state = next_mus[0, 0, index].view(1, 1, state_dim)

            print(
                "{} cur_state: {}, action: {}, next_state: {}, reward: {}, terminal: {}"
                .format(e, cur_state, action, next_state, reward, terminal))
            print("next_pi: {}, sampled index: {}".format(next_pi, index))
            print("next_mus:", next_mus, "\n")

            # Persist the transition into the replay buffers.
            cur_state_mem[e, s, :] = cur_state.detach().numpy()
            action_mem[e, s, :] = action.numpy()
            reward_mem[e, s] = reward.detach().numpy()
            terminal_mem[e, s] = terminal
            next_state_mem[e, s, :] = next_state.detach().numpy()
            next_mus_mem[e, s, :, :] = next_mus.detach().numpy()

    # --- phase 2: train the MDRNN on the generated data ---
    mdrnn = MDRNN(
        latents=state_dim,
        actions=action_dim,
        gaussians=mdrnn_num_gaussian,
        hiddens=mdrnn_hidden_size,
        layers=mdrnn_hidden_layer,
    )
    mdrnn.train()
    optimizer = torch.optim.Adam(mdrnn.parameters(), lr=adam_lr)
    num_batch = num_episodes // batch_size
    earlystopping = EarlyStopping('min', patience=30)

    # Per-step histories of each loss component.
    cum_loss = []
    cum_gmm = []
    cum_bce = []
    cum_mse = []
    for e in range(num_epochs):
        for i in range(0, num_batch):
            mdrnn.init_hidden(batch_size=batch_size)
            optimizer.zero_grad()
            # Sample episodes with replacement for this mini-batch.
            sample_indices = numpy.random.randint(num_episodes,
                                                  size=batch_size)
            obs, action, reward, terminal, next_obs = \
                cur_state_mem[sample_indices], \
                action_mem[sample_indices], \
                reward_mem[sample_indices], \
                terminal_mem[sample_indices], \
                next_state_mem[sample_indices]
            obs, action, reward, terminal, next_obs = \
                torch.tensor(obs, dtype=torch.float), \
                torch.tensor(action, dtype=torch.float), \
                torch.tensor(reward, dtype=torch.float), \
                torch.tensor(terminal, dtype=torch.float), \
                torch.tensor(next_obs, dtype=torch.float)
            print("learning at epoch {} step {} best score {} counter {}".
                  format(e, i, earlystopping.best,
                         earlystopping.num_bad_epochs))
            # self.get_loss is defined elsewhere in this test class;
            # presumably returns dict with 'loss', 'gmm', 'bce', 'mse'.
            losses = self.get_loss(obs, action, reward, terminal, next_obs,
                                   state_dim, mdrnn)
            losses['loss'].backward()
            optimizer.step()
            cum_loss += [losses['loss'].item()]
            cum_gmm += [losses['gmm'].item()]
            cum_bce += [losses['bce'].item()]
            cum_mse += [losses['mse'].item()]
            print(
                "loss={loss:10.6f} bce={bce:10.6f} gmm={gmm:10.6f} mse={mse:10.6f}"
                .format(
                    loss=losses['loss'],
                    bce=losses['bce'],
                    gmm=losses['gmm'],
                    mse=losses['mse'],
                ))
            print(
                "cum loss={loss:10.6f} cum bce={bce:10.6f} cum gmm={gmm:10.6f} cum mse={mse:10.6f}"
                .format(
                    loss=numpy.mean(cum_loss),
                    bce=numpy.mean(cum_bce),
                    gmm=numpy.mean(cum_gmm),
                    mse=numpy.mean(cum_mse),
                ))
            print()
        # Step the early-stopping monitor on the mean loss of the most
        # recent num_batch steps (i.e. roughly this epoch's mean loss).
        earlystopping.step(numpy.mean(cum_loss[-num_batch:]))
        # Stop once the loss target is reached AND the monitor has
        # plateaued — both conditions are required to break early.
        if numpy.mean(cum_loss[-num_batch:]) < -3. and earlystopping.stop:
            break

    # Training must have reached the loss target, whether or not it
    # broke early. NOTE(review): with fixed num_epochs this can flake if
    # training is slow to converge — consider seeding the RNGs.
    assert numpy.mean(cum_loss[-num_batch:]) < -3.

    # --- phase 3: single-episode eval forward pass ---
    sample_indices = [0]
    mdrnn.init_hidden(batch_size=len(sample_indices))
    mdrnn.eval()
    obs, action, reward, terminal, next_obs = \
        cur_state_mem[sample_indices], \
        action_mem[sample_indices], \
        reward_mem[sample_indices], \
        terminal_mem[sample_indices], \
        next_state_mem[sample_indices]
    obs, action, reward, terminal, next_obs = \
        torch.tensor(obs, dtype=torch.float), \
        torch.tensor(action, dtype=torch.float), \
        torch.tensor(reward, dtype=torch.float), \
        torch.tensor(terminal, dtype=torch.float), \
        torch.tensor(next_obs, dtype=torch.float)
    # self.transpose is defined elsewhere; presumably swaps batch/time
    # axes to the (seq, batch, ...) layout MDRNN expects — TODO confirm.
    transpose_obs, transpose_action, transpose_reward, transpose_terminal, transpose_next_obs = \
        self.transpose(obs, action, reward, terminal, next_obs)
    mus, sigmas, logpi, rs, ds = mdrnn(transpose_action, transpose_obs)
    pi = torch.exp(logpi)  # mixture weights from log-weights
    gl = gmm_loss(transpose_next_obs, mus, sigmas, logpi)
    print(gl)
    print()
# Script fragment: optionally reload an MDRNN checkpoint, then run the
# model over a held-out rollout dataset collecting mixture parameters.
# NOTE(review): source formatting was collapsed; indentation below is
# reconstructed from syntax — confirm block boundaries against the
# original file. The loop body may continue past this chunk (mu_record
# and friends are never appended here).

# Resume from checkpoint unless the user passed --noreload.
if exists(rnn_file) and not args.noreload:
    rnn_state = torch.load(rnn_file)
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(rnn_state["epoch"],
                                      rnn_state["precision"]))
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    # scheduler.load_state_dict(state['scheduler'])
    # earlystopping.load_state_dict(state['earlystopping'])

test_loader = DataLoader(
    _RolloutDataset('datasets/carracing2_rnn', train=False),
    batch_size=BSIZE,
    num_workers=8)
mdrnn.eval()
loader = test_loader

# Accumulators for per-batch mixture outputs; not filled within this
# visible chunk — presumably appended to later in the loop body.
mu_record = []
sigma_record = []
pi_record = []
output_record = []
for i, data in enumerate(loader):
    data = data.cuda()
    # Teacher forcing: feed steps [0, T-1) and predict steps [1, T).
    # NOTE(review): `input` shadows the builtin — consider renaming.
    input = data[:, :-1, :]
    # Targets: presumably the first 32 features are the latent state —
    # TODO confirm against _RolloutDataset's layout.
    output = data[:, 1:, :32]
    with torch.no_grad():
        mu, sigma, pi = mdrnn(input, train=False)
        # Reshape to (batch, seq, latent, gaussians); the hard-coded
        # (8, 999, 32, 5) assumes BSIZE == 8, seq len 999, 32 latents,
        # 5 mixture components — TODO confirm, breaks on a ragged
        # final batch.
        sigma = torch.exp(sigma).transpose(1, 0).view(8, 999, 32, 5)
        pi = torch.softmax(pi, dim=-1).transpose(1, 0).view(8, 999, 32, 5)