def observe(self, obs, reward, done, reset):
    """Record the outcome of the last action and ship it to the learner.

    In training mode, builds a transition from the remembered
    (state, action) pair and sends it to the learner process; in
    evaluation mode only recurrent bookkeeping is performed.
    """
    if not self.training:
        # Evaluation: just drop the recurrent state at episode boundaries.
        if (done or reset) and self.recurrent:
            self.test_recurrent_states = None
        return

    self.t += 1
    assert self.last_state is not None
    assert self.last_action is not None

    # Transition destined for the learner-side replay buffer.
    transition = {
        "state": self.last_state,
        "action": self.last_action,
        "reward": reward,
        "next_state": obs,
        "is_state_terminal": done,
    }
    if self.recurrent:
        # Snapshot the recurrent state from before and after this step.
        transition["recurrent_state"] = recurrent_state_as_numpy(
            get_recurrent_state_at(self.train_prev_recurrent_states, 0, detach=True)
        )
        self.train_prev_recurrent_states = None
        transition["next_recurrent_state"] = recurrent_state_as_numpy(
            get_recurrent_state_at(self.train_recurrent_states, 0, detach=True)
        )

    episode_over = done or reset
    self._send_to_learner(transition, stop_episode=episode_over)
    if episode_over and self.recurrent:
        # Start the next episode from a fresh recurrent state.
        self.train_prev_recurrent_states = None
        self.train_recurrent_states = None
def _test_non_lstm(self, gpu, name):
    """Exercise a single-layer GRU/RNN together with the recurrent-state
    helpers: masking, extraction, and concatenation of batch states."""
    in_size = 2
    out_size = 3
    device = "cuda:{}".format(gpu) if gpu >= 0 else "cpu"
    sequences = [
        torch.rand(4, in_size, device=device),
        torch.rand(1, in_size, device=device),
        torch.rand(3, in_size, device=device),
    ]
    packed = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted=False)
    self.assertTrue(name in ("GRU", "RNN"))
    link = getattr(nn, name)(num_layers=1, input_size=in_size, hidden_size=out_size)
    link.to(device)

    # Forward twice: once from a None state, once from the resulting state.
    y0, h0 = link(packed, None)
    y1, h1 = link(packed, h0)
    y0, _ = torch.nn.utils.rnn.pad_packed_sequence(y0, batch_first=True)
    y1, _ = torch.nn.utils.rnn.pad_packed_sequence(y1, batch_first=True)
    for hidden in (h0, h1):
        self.assertEqual(hidden.shape, (1, 3, out_size))
    for output in (y0, y1):
        self.assertEqual(output.shape, (3, 4, out_size))

    # Masking batch index 0 resets only that sequence's state.
    masked_at_0 = mask_recurrent_state_at(h0, 0)
    ym0, _ = link(packed, masked_at_0)
    ym0, _ = torch.nn.utils.rnn.pad_packed_sequence(ym0, batch_first=True)
    torch_assert_allclose(ym0[0], y0[0])
    torch_assert_allclose(ym0[1], y1[1])
    torch_assert_allclose(ym0[2], y1[2])

    # Masking batch indices (1, 2) resets those sequences' states.
    masked_at_12 = mask_recurrent_state_at(h0, (1, 2))
    ym12, _ = link(packed, masked_at_12)
    ym12, _ = torch.nn.utils.rnn.pad_packed_sequence(ym12, batch_first=True)
    torch_assert_allclose(ym12[0], y1[0])
    torch_assert_allclose(ym12[1], y0[1])
    torch_assert_allclose(ym12[2], y0[2])

    # Extract batch index 1 (keeping grad) and recombine with None states.
    picked = get_recurrent_state_at(h0, 1, detach=False)
    assert picked.requires_grad
    torch_assert_allclose(picked, h0[:, 1])
    recombined = concatenate_recurrent_states([None, picked, None])
    yg1, _ = link(packed, recombined)
    yg1, _ = torch.nn.utils.rnn.pad_packed_sequence(yg1, batch_first=True)
    torch_assert_allclose(yg1[0], y0[0])
    torch_assert_allclose(yg1[1], y1[1])
    torch_assert_allclose(yg1[2], y0[2])

    # detach=True must return the same values without grad tracking.
    picked_detached = get_recurrent_state_at(h0, 1, detach=True)
    assert not picked_detached.requires_grad
    torch_assert_allclose(picked_detached, h0[:, 1])
def _batch_observe_train(
    self,
    batch_obs: Sequence[Any],
    batch_reward: Sequence[float],
    batch_done: Sequence[bool],
    batch_reset: Sequence[bool],
) -> None:
    """Process one training-time step from a batch of environments.

    For every env index: advances the step counters, periodically syncs
    the target network, appends a transition (built from the previously
    remembered obs/action) to the replay buffer, and triggers the replay
    updater. After the loop, recurrent states of ended episodes are reset.

    Args:
        batch_obs: Latest observation from each environment.
        batch_reward: Reward received by each environment.
        batch_done: Whether each episode terminated.
        batch_reset: Whether each episode was externally reset/truncated.
    """
    for i in range(len(batch_obs)):
        # Global step counters advance once per env transition, not per call.
        self.t += 1
        self._cumulative_steps += 1
        # Update the target network
        if self.t % self.target_update_interval == 0:
            self.sync_target_network()
        # batch_last_obs[i] is None right after an episode boundary,
        # in which case there is no completed transition to store.
        if self.batch_last_obs[i] is not None:
            assert self.batch_last_action[i] is not None
            # Add a transition to the replay buffer
            transition = {
                "state": self.batch_last_obs[i],
                "action": self.batch_last_action[i],
                "reward": batch_reward[i],
                # batch_h[i]: per-env feature stored alongside the
                # transition — produced elsewhere; TODO confirm semantics.
                "feature": self.batch_h[i],
                "next_state": batch_obs[i],
                "next_action": None,
                "is_state_terminal": batch_done[i],
            }
            if self.recurrent:
                # Detached per-env recurrent states from before/after the step.
                transition["recurrent_state"] = recurrent_state_as_numpy(
                    get_recurrent_state_at(
                        self.train_prev_recurrent_states, i, detach=True))
                transition[
                    "next_recurrent_state"] = recurrent_state_as_numpy(
                        get_recurrent_state_at(self.train_recurrent_states,
                                               i, detach=True))
            self.replay_buffer.append(env_id=i, **transition)
            self._backup_if_necessary(self.t, self.batch_h[i])
            if batch_reset[i] or batch_done[i]:
                # Forget the last obs/action so no transition is built
                # across the episode boundary.
                self.batch_last_obs[i] = None
                self.batch_last_action[i] = None
                self.replay_buffer.stop_current_episode(env_id=i)
        self.replay_updater.update_if_necessary(self.t)
    if self.recurrent:
        # Reset recurrent states when episodes end
        self.train_prev_recurrent_states = None
        self.train_recurrent_states = _batch_reset_recurrent_states_when_episodes_end(  # NOQA
            batch_done=batch_done,
            batch_reset=batch_reset,
            recurrent_states=self.train_recurrent_states,
        )
def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset):
    """Store one step of batched training experience into episode buffers.

    Pairs each env's remembered (state, action) with the new observation
    to form a transition, appends it to that env's running episode, and
    moves finished episodes into ``self.memory``. Recurrent states of
    envs whose episodes ended are masked (reset) afterwards.

    Args:
        batch_obs: Latest observation from each environment.
        batch_reward: Reward received by each environment.
        batch_done: Whether each episode terminated.
        batch_reset: Whether each episode was externally reset/truncated.
    """
    assert self.training
    for i, (state, action, reward, next_state, done, reset) in enumerate(
        zip(  # NOQA
            self.batch_last_state,
            self.batch_last_action,
            batch_reward,
            batch_obs,
            batch_done,
            batch_reset,
        )
    ):
        # state is None right after an episode boundary: nothing to pair.
        if state is not None:
            assert action is not None
            transition = {
                "state": state,
                "action": action,
                "reward": reward,
                "next_state": next_state,
                # 1.0 while the episode continues, 0.0 on termination.
                "nonterminal": 0.0 if done else 1.0,
            }
            if self.recurrent:
                # Detached per-env recurrent states from before/after the step.
                transition["recurrent_state"] = get_recurrent_state_at(
                    self.train_prev_recurrent_states, i, detach=True
                )
                transition["next_recurrent_state"] = get_recurrent_state_at(
                    self.train_recurrent_states, i, detach=True
                )
            self.batch_last_episode[i].append(transition)
        if done or reset:
            # Move the finished episode into memory and clear per-env state.
            assert self.batch_last_episode[i]
            self.memory.append(self.batch_last_episode[i])
            self.batch_last_episode[i] = []
            self.batch_last_state[i] = None
            self.batch_last_action[i] = None
    self.train_prev_recurrent_states = None
    if self.recurrent:
        # Reset recurrent states when episodes end
        indices_that_ended = [
            i
            for i, (done, reset) in enumerate(zip(batch_done, batch_reset))
            if done or reset
        ]
        if indices_that_ended:
            self.train_recurrent_states = mask_recurrent_state_at(
                self.train_recurrent_states, indices_that_ended
            )
    self._update_if_dataset_is_ready()
def get_and_concat_rs_forward():
    """Step once, split the resulting recurrent state per batch index,
    re-concatenate the pieces, and step again from the rebuilt state."""
    _, full_state = one_step_forward(par, x_t0, None)
    pieces = [get_recurrent_state_at(full_state, idx, detach=True) for idx in (0, 1)]
    rebuilt_state = concatenate_recurrent_states(pieces)
    return one_step_forward(par, x_t1, rebuilt_state)