import logging

import numpy as np
import torch
from torch.distributions import Categorical

logger = logging.getLogger(__name__)

# Repo-internal names (OpenAIGymEnvironment, MDNRNNTrainer, OpenAiRunDetails,
# get_replay_buffer, loss_to_num, MDNRNNMemoryPool, SimulatedWorldModel,
# MDNRNNParameters, MemoryNetwork) are assumed to be imported elsewhere in
# this module/package.


def train_sgd(
    gym_env: OpenAIGymEnvironment,
    trainer: MDNRNNTrainer,
    use_gpu: bool,
    test_run_name: str,
    minibatch_size: int,
    run_details: OpenAiRunDetails,
):
    """Train the MDN-RNN world model with minibatch SGD, validating after
    every epoch (with early stopping) and evaluating once on a held-out
    test buffer."""
    assert run_details.max_steps is not None
    train_replay_buffer = get_replay_buffer(
        run_details.num_train_episodes,
        run_details.seq_len,
        run_details.max_steps,
        gym_env,
    )
    valid_replay_buffer = get_replay_buffer(
        run_details.num_test_episodes,
        run_details.seq_len,
        run_details.max_steps,
        gym_env,
    )
    test_replay_buffer = get_replay_buffer(
        run_details.num_test_episodes,
        run_details.seq_len,
        run_details.max_steps,
        gym_env,
    )
    valid_loss_history = []

    num_batch_per_epoch = train_replay_buffer.memory_size // minibatch_size
    logger.info(
        "Collected {} transitions.\n"
        "Training will take {} epochs, with each epoch having {} mini-batches"
        " and each mini-batch having {} samples".format(
            train_replay_buffer.memory_size,
            run_details.train_epochs,
            num_batch_per_epoch,
            minibatch_size,
        )
    )

    for i_epoch in range(run_details.train_epochs):
        for i_batch in range(num_batch_per_epoch):
            training_batch = train_replay_buffer.sample_memories(
                minibatch_size, use_gpu=use_gpu, batch_first=True
            )
            losses = trainer.train(training_batch, batch_first=True)
            logger.info(
                "{}-th epoch, {}-th minibatch: \n"
                "loss={}, bce={}, gmm={}, mse={} \n"
                "cum loss={}, cum bce={}, cum gmm={}, cum mse={}\n".format(
                    i_epoch,
                    i_batch,
                    losses["loss"],
                    losses["bce"],
                    losses["gmm"],
                    losses["mse"],
                    np.mean(trainer.cum_loss),
                    np.mean(trainer.cum_bce),
                    np.mean(trainer.cum_gmm),
                    np.mean(trainer.cum_mse),
                )
            )

        # Validate in eval mode at the end of each epoch, then restore
        # train mode before the next epoch.
        trainer.mdnrnn.mdnrnn.eval()
        valid_batch = valid_replay_buffer.sample_memories(
            valid_replay_buffer.memory_size, use_gpu=use_gpu, batch_first=True
        )
        valid_losses = trainer.get_loss(
            valid_batch, state_dim=gym_env.state_dim, batch_first=True
        )
        valid_losses = loss_to_num(valid_losses)
        valid_loss_history.append(valid_losses)
        trainer.mdnrnn.mdnrnn.train()
        logger.info(
            "{}-th epoch, validate loss={}, bce={}, gmm={}, mse={}".format(
                i_epoch,
                valid_losses["loss"],
                valid_losses["bce"],
                valid_losses["gmm"],
                valid_losses["mse"],
            )
        )

        # Early stopping: stop once the latest validation loss fails to
        # improve on every loss in the preceding patience-sized window.
        latest_loss = valid_loss_history[-1]["loss"]
        recent_valid_loss_hist = valid_loss_history[
            -1 - run_details.early_stopping_patience : -1
        ]
        if len(valid_loss_history) > run_details.early_stopping_patience and all(
            latest_loss >= v["loss"] for v in recent_valid_loss_hist
        ):
            break

    # Final held-out evaluation in eval mode.
    trainer.mdnrnn.mdnrnn.eval()
    test_batch = test_replay_buffer.sample_memories(
        test_replay_buffer.memory_size, use_gpu=use_gpu, batch_first=True
    )
    test_losses = trainer.get_loss(
        test_batch, state_dim=gym_env.state_dim, batch_first=True
    )
    test_losses = loss_to_num(test_losses)
    logger.info(
        "Test loss: {}, bce={}, gmm={}, mse={}".format(
            test_losses["loss"],
            test_losses["bce"],
            test_losses["gmm"],
            test_losses["mse"],
        )
    )
    logger.info("Valid loss history: {}".format(valid_loss_history))
    return test_losses, valid_loss_history, trainer
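
# Illustrative sketch (not part of the original module): the early-stopping
# rule used inline in train_sgd above, isolated as a plain-float helper so the
# slice arithmetic is easier to see. The name _should_early_stop is
# hypothetical; train_sgd itself compares loss dicts via v["loss"] rather
# than raw floats.
def _should_early_stop(loss_history, patience):
    """Return True once the latest loss fails to improve on every loss in
    the preceding `patience`-sized window.

    >>> _should_early_stop([0.9, 0.5, 0.5, 0.5, 0.52], patience=3)
    True
    >>> _should_early_stop([0.9, 0.5, 0.5, 0.5, 0.49], patience=3)
    False
    """
    if len(loss_history) <= patience:
        return False
    latest = loss_history[-1]
    recent = loss_history[-1 - patience : -1]
    return all(latest >= v for v in recent)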

def test_mdnrnn_simulate_world(self):
    """Generate episodes from a fixed SimulatedWorldModel and verify that an
    MDN-RNN trained on them drives its cumulative losses below preset
    thresholds."""
    num_epochs = 300
    num_episodes = 400
    batch_size = 200
    action_dim = 2
    seq_len = 5
    state_dim = 2
    simulated_num_gaussians = 2
    mdnrnn_num_gaussians = 2
    simulated_num_hidden_layers = 1
    simulated_num_hiddens = 3
    mdnrnn_num_hidden_layers = 1
    mdnrnn_num_hiddens = 10
    adam_lr = 0.01

    replay_buffer = MDNRNNMemoryPool(max_replay_memory_size=num_episodes)
    swm = SimulatedWorldModel(
        action_dim=action_dim,
        state_dim=state_dim,
        num_gaussians=simulated_num_gaussians,
        lstm_num_hidden_layers=simulated_num_hidden_layers,
        lstm_num_hiddens=simulated_num_hiddens,
    )

    # Roll out episodes from the simulated world model and store them.
    possible_actions = torch.eye(action_dim)
    for _ in range(num_episodes):
        cur_state_mem = np.zeros((seq_len, state_dim))
        next_state_mem = np.zeros((seq_len, state_dim))
        action_mem = np.zeros((seq_len, action_dim))
        reward_mem = np.zeros(seq_len)
        not_terminal_mem = np.zeros(seq_len)
        next_mus_mem = np.zeros((seq_len, simulated_num_gaussians, state_dim))

        swm.init_hidden(batch_size=1)
        next_state = torch.randn((1, 1, state_dim))
        for s in range(seq_len):
            cur_state = next_state
            action = possible_actions[np.random.randint(action_dim)].view(
                1, 1, action_dim
            )
            next_mus, reward = swm(action, cur_state)

            not_terminal = 1
            if s == seq_len - 1:
                not_terminal = 0

            # Randomly draw the next state from a uniform mixture over the
            # simulated model's Gaussian components.
            next_pi = torch.ones(simulated_num_gaussians) / simulated_num_gaussians
            index = Categorical(next_pi).sample((1,)).long().item()
            next_state = next_mus[0, 0, index].view(1, 1, state_dim)

            cur_state_mem[s] = cur_state.detach().numpy()
            action_mem[s] = action.numpy()
            reward_mem[s] = reward.detach().numpy()
            not_terminal_mem[s] = not_terminal
            next_state_mem[s] = next_state.detach().numpy()
            next_mus_mem[s] = next_mus.detach().numpy()

        replay_buffer.insert_into_memory(
            cur_state_mem, action_mem, next_state_mem, reward_mem, not_terminal_mem
        )

    num_batch = num_episodes // batch_size
    mdnrnn_params = MDNRNNParameters(
        hidden_size=mdnrnn_num_hiddens,
        num_hidden_layers=mdnrnn_num_hidden_layers,
        minibatch_size=batch_size,
        learning_rate=adam_lr,
        num_gaussians=mdnrnn_num_gaussians,
    )
    mdnrnn_net = MemoryNetwork(
        state_dim=state_dim,
        action_dim=action_dim,
        num_hiddens=mdnrnn_params.hidden_size,
        num_hidden_layers=mdnrnn_params.num_hidden_layers,
        num_gaussians=mdnrnn_params.num_gaussians,
    )
    trainer = MDNRNNTrainer(
        mdnrnn_network=mdnrnn_net, params=mdnrnn_params, cum_loss_hist=num_batch
    )

    # Train and return as soon as all cumulative losses clear the thresholds.
    for e in range(num_epochs):
        for i in range(num_batch):
            training_batch = replay_buffer.sample_memories(batch_size)
            losses = trainer.train(training_batch)
            logger.info(
                "{}-th epoch, {}-th minibatch: \n"
                "loss={}, bce={}, gmm={}, mse={} \n"
                "cum loss={}, cum bce={}, cum gmm={}, cum mse={}\n".format(
                    e,
                    i,
                    losses["loss"],
                    losses["bce"],
                    losses["gmm"],
                    losses["mse"],
                    np.mean(trainer.cum_loss),
                    np.mean(trainer.cum_bce),
                    np.mean(trainer.cum_gmm),
                    np.mean(trainer.cum_mse),
                )
            )

            if (
                np.mean(trainer.cum_loss) < 0
                and np.mean(trainer.cum_gmm) < -3.0
                and np.mean(trainer.cum_bce) < 0.6
                and np.mean(trainer.cum_mse) < 0.2
            ):
                return

    assert False, "losses not reduced significantly during training"
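
# Illustrative sketch (not part of the original test): the next-state draw
# from the episode loop above, isolated as a helper. With uniform mixture
# weights, a Gaussian component index is sampled from a Categorical
# distribution and that component's mean is used directly as the next state;
# no per-component noise is added. The name _sample_next_state is
# hypothetical; it assumes next_mus has shape (1, 1, num_gaussians, state_dim)
# as produced by SimulatedWorldModel in the loop above.
def _sample_next_state(next_mus, num_gaussians, state_dim):
    next_pi = torch.ones(num_gaussians) / num_gaussians  # uniform weights
    index = Categorical(next_pi).sample((1,)).long().item()  # pick a component
    return next_mus[0, 0, index].view(1, 1, state_dim)  # component mean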