def _test_non_lstm(self, gpu, name):
    in_size = 2
    out_size = 3
    device = "cuda:{}".format(gpu) if gpu >= 0 else "cpu"
    seqs_x = [
        torch.rand(4, in_size, device=device),
        torch.rand(1, in_size, device=device),
        torch.rand(3, in_size, device=device),
    ]
    seqs_x = torch.nn.utils.rnn.pack_sequence(seqs_x, enforce_sorted=False)
    self.assertTrue(name in ("GRU", "RNN"))
    cls = getattr(nn, name)
    link = cls(num_layers=1, input_size=in_size, hidden_size=out_size)
    link.to(device)

    # Forward twice: with None and non-None random states
    y0, h0 = link(seqs_x, None)
    y1, h1 = link(seqs_x, h0)
    y0, _ = torch.nn.utils.rnn.pad_packed_sequence(y0, batch_first=True)
    y1, _ = torch.nn.utils.rnn.pad_packed_sequence(y1, batch_first=True)
    self.assertEqual(h0.shape, (1, 3, out_size))
    self.assertEqual(h1.shape, (1, 3, out_size))
    self.assertEqual(y0.shape, (3, 4, out_size))
    self.assertEqual(y1.shape, (3, 4, out_size))

    # Masked at 0
    rs0_mask0 = mask_recurrent_state_at(h0, 0)
    y1m0, _ = link(seqs_x, rs0_mask0)
    y1m0, _ = torch.nn.utils.rnn.pad_packed_sequence(y1m0, batch_first=True)
    torch_assert_allclose(y1m0[0], y0[0])
    torch_assert_allclose(y1m0[1], y1[1])
    torch_assert_allclose(y1m0[2], y1[2])

    # Masked at (1, 2)
    rs0_mask12 = mask_recurrent_state_at(h0, (1, 2))
    y1m12, _ = link(seqs_x, rs0_mask12)
    y1m12, _ = torch.nn.utils.rnn.pad_packed_sequence(y1m12, batch_first=True)
    torch_assert_allclose(y1m12[0], y1[0])
    torch_assert_allclose(y1m12[1], y0[1])
    torch_assert_allclose(y1m12[2], y0[2])

    # Get at 1 and concat with None
    rs0_get1 = get_recurrent_state_at(h0, 1, detach=False)
    assert rs0_get1.requires_grad
    torch_assert_allclose(rs0_get1, h0[:, 1])
    concat_rs_get1 = concatenate_recurrent_states([None, rs0_get1, None])
    y1g1, _ = link(seqs_x, concat_rs_get1)
    y1g1, _ = torch.nn.utils.rnn.pad_packed_sequence(y1g1, batch_first=True)
    torch_assert_allclose(y1g1[0], y0[0])
    torch_assert_allclose(y1g1[1], y1[1])
    torch_assert_allclose(y1g1[2], y0[2])

    # Get at 1 with detach=True
    rs0_get1_detach = get_recurrent_state_at(h0, 1, detach=True)
    assert not rs0_get1_detach.requires_grad
    torch_assert_allclose(rs0_get1_detach, h0[:, 1])
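For reference, a minimal sketch of the semantics this test verifies, assuming plain (num_layers, batch, hidden) tensor states as in the shape checks above and that torch is already imported. The real pfrl helpers also handle LSTM (h, c) tuples and nested states; the *_sketch functions below are hypothetical illustrations, not the library implementations.

def mask_recurrent_state_at_sketch(state, indices):
    # Zero out the state of the given batch entries so they behave as if
    # started from a fresh None state (cf. the y1m0[0] == y0[0] check).
    indices = [indices] if isinstance(indices, int) else list(indices)
    mask = torch.ones(state.shape[1], 1, device=state.device)
    mask[indices] = 0.0
    return state * mask

def get_recurrent_state_at_sketch(state, i, detach):
    # Select one batch entry (cf. the rs0_get1 == h0[:, 1] check).
    selected = state[:, i]
    return selected.detach() if detach else selected

def concatenate_recurrent_states_sketch(states):
    # Stack per-entry states back into a batch; None entries become zero
    # states (cf. the y1g1[0] == y0[0] and y1g1[2] == y0[2] checks).
    template = next(s for s in states if s is not None)
    filled = [torch.zeros_like(template) if s is None else s for s in states]
    return torch.stack(filled, dim=1)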
def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset):
    assert self.training
    for i, (state, action, reward, next_state, done, reset) in enumerate(
        zip(  # NOQA
            self.batch_last_state,
            self.batch_last_action,
            batch_reward,
            batch_obs,
            batch_done,
            batch_reset,
        )
    ):
        if state is not None:
            assert action is not None
            transition = {
                "state": state,
                "action": action,
                "reward": reward,
                "next_state": next_state,
                "nonterminal": 0.0 if done else 1.0,
            }
            if self.recurrent:
                # Store detached per-env recurrent states so that stored
                # transitions do not keep the computation graph alive.
                transition["recurrent_state"] = get_recurrent_state_at(
                    self.train_prev_recurrent_states, i, detach=True
                )
                transition["next_recurrent_state"] = get_recurrent_state_at(
                    self.train_recurrent_states, i, detach=True
                )
            self.batch_last_episode[i].append(transition)
        if done or reset:
            assert self.batch_last_episode[i]
            self.memory.append(self.batch_last_episode[i])
            self.batch_last_episode[i] = []
            self.batch_last_state[i] = None
            self.batch_last_action[i] = None

    self.train_prev_recurrent_states = None

    if self.recurrent:
        # Reset recurrent states when episodes end
        indices_that_ended = [
            i
            for i, (done, reset) in enumerate(zip(batch_done, batch_reset))
            if done or reset
        ]
        if indices_that_ended:
            self.train_recurrent_states = mask_recurrent_state_at(
                self.train_recurrent_states, indices_that_ended
            )

    self._update_if_dataset_is_ready()
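For context, a hedged sketch of the batch training loop that would drive this method through BatchAgent.batch_observe. The vector-env API and the "needs_reset" info key follow pfrl's conventions, but treat the details here as assumptions rather than the exact training loop.

obss = vec_env.reset()
for _ in range(max_steps):
    actions = agent.batch_act(obss)
    obss, rewards, dones, infos = vec_env.step(actions)
    # Episodes can end without a terminal state, e.g. on a time limit.
    resets = [info.get("needs_reset", False) for info in infos]
    agent.batch_observe(obss, rewards, dones, resets)
    # (A real loop would also reset finished sub-envs before the next step.)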
def _batch_observe_eval(self, batch_obs, batch_reward, batch_done, batch_reset):
    assert not self.training
    if self.recurrent:
        # Reset recurrent states when episodes end
        indices_that_ended = [
            i
            for i, (done, reset) in enumerate(zip(batch_done, batch_reset))
            if done or reset
        ]
        if indices_that_ended:
            self.test_recurrent_states = mask_recurrent_state_at(
                self.test_recurrent_states, indices_that_ended
            )
def _batch_reset_recurrent_states_when_episodes_end(
    batch_done, batch_reset, recurrent_states
):
    """Reset recurrent states when episodes end.

    Args:
        batch_done (array-like of bool): True iff episodes are terminal.
        batch_reset (array-like of bool): True iff episodes will be reset.
        recurrent_states (object): Recurrent state.

    Returns:
        object: New recurrent states.
    """
    indices_that_ended = [
        i
        for i, (done, reset) in enumerate(zip(batch_done, batch_reset))
        if done or reset
    ]
    if indices_that_ended:
        return mask_recurrent_state_at(recurrent_states, indices_that_ended)
    else:
        return recurrent_states
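Example use of the helper above, with hypothetical flag values: env 0 terminated and env 2 hit a reset condition, so only their slices of the recurrent state are masked; env 1's state is carried over unchanged.

new_states = _batch_reset_recurrent_states_when_episodes_end(
    batch_done=[True, False, False],
    batch_reset=[False, False, True],
    recurrent_states=recurrent_states,
)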
def mask01_forward_twice():
    # par, x_t0, x_t1, and one_step_forward come from the enclosing scope.
    # Step once from a fresh (None) state, mask the recurrent state of batch
    # entries 0 and 1, then step again with the partially reset state.
    _, rs = one_step_forward(par, x_t0, None)
    rs = mask_recurrent_state_at(rs, [0, 1])
    return one_step_forward(par, x_t1, rs)
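A hedged sketch of the property this closure would be used to check, assuming one_step_forward returns (output, next_recurrent_state) with batch-first outputs: after masking entries 0 and 1, those entries should match a forward pass started from a fresh None state.

out_masked, _ = mask01_forward_twice()
out_fresh, _ = one_step_forward(par, x_t1, None)
torch_assert_allclose(out_masked[:2], out_fresh[:2])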