import numpy as np
import torch
import torch.nn as nn

# ModuleContainer, RNN, Categorical, Normal, xavier_init, and np_to_var are
# assumed to be provided elsewhere in this package.


class LSTMPolicy(ModuleContainer):
    """Recurrent policy that maps observations to a categorical action distribution."""

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.rnn = RNN(self.lstm, hidden_dim)
        self.softmax = nn.Softmax(dim=-1)  # explicit dim; bare nn.Softmax() is deprecated
        self.linear = nn.Linear(hidden_dim, output_dim)
        self.modules = [self.rnn, self.linear]

    def reset(self, bs):
        # Re-initialise the LSTM hidden state for a batch of size bs.
        self.rnn.init_hidden(bs)

    def forward(self, x):
        # Add a leading time dimension if a single (batch, feature) step is given.
        if len(x.size()) == 2:
            x = x.unsqueeze(0)
        lstm_out = self.rnn.forward(x)
        lstm_reshape = lstm_out.view((-1, self.hidden_dim))
        output = self.softmax(self.linear(lstm_reshape))
        dist = Categorical(output)
        return dist

    def set_state(self, state):
        self.rnn.set_state(state)

    def get_state(self):
        return self.rnn.get_state()

    def recurrent(self):
        return True
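A minimal rollout sketch showing how the recurrent policy might be driven for one episode: reset(bs) re-initialises the hidden state, and forward() returns a distribution per step. The helper name collect_episode, the env object, and obs_to_tensor are illustrative assumptions, not part of the package.

# Illustrative only: a hypothetical episode-collection loop for LSTMPolicy.
# Assumes Categorical exposes sample()/log_prob() and env follows the Gym step API.
def collect_episode(policy, env, obs_to_tensor, max_steps=200):
    policy.reset(bs=1)                              # fresh hidden state per episode
    obs = env.reset()
    actions, log_probs = [], []
    for _ in range(max_steps):
        dist = policy.forward(obs_to_tensor(obs))   # (1, input_dim) -> Categorical
        action = dist.sample()
        log_probs.append(dist.log_prob(action))
        actions.append(action)
        obs, reward, done, _ = env.step(action.item())
        if done:
            break
    return actions, log_probs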
class GaussianLSTMPolicy(ModuleContainer):
    """Recurrent policy that maps observations to a Gaussian action distribution."""

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim,
                 log_var_network, init=xavier_init, scale_final=False,
                 min_var=1e-4, obs_filter=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.rnn = RNN(self.lstm, hidden_dim)
        self.softmax = nn.Softmax(dim=-1)  # explicit dim; bare nn.Softmax() is deprecated
        self.linear = nn.Linear(hidden_dim, output_dim)
        self.log_var_network = log_var_network
        self.modules = [self.rnn, self.linear, self.log_var_network]
        self.obs_filter = obs_filter
        # Lower bound on the log-variance, used by torch.max in forward().
        self.min_log_var = np_to_var(
            np.log(np.array([min_var])).astype(np.float32))
        self.apply(init)
        # self.apply(weights_init_mlp)
        if scale_final:
            # The original referenced self.mean_network, which this class does not
            # define; scale the final mean layer (self.linear) instead.
            self.linear.weight.data.mul_(0.01)

    def forward(self, x):
        if self.obs_filter is not None:
            x.data = self.obs_filter(x.data)
        # Add a leading time dimension if a single (batch, feature) step is given.
        if len(x.size()) == 2:
            x = x.unsqueeze(0)
        lstm_out = self.rnn.forward(x)
        lstm_reshape = lstm_out.view((-1, self.hidden_dim))
        mean = self.softmax(self.linear(lstm_reshape))
        # The log-variance is predicted directly from the flattened observation
        # and clamped from below at log(min_var).
        log_var = self.log_var_network(x.contiguous().view((-1, x.shape[-1])))
        log_var = torch.max(self.min_log_var, log_var)  # TODO Limit log var
        dist = Normal(mean=mean, log_var=log_var)
        return dist

    def reset(self, bs):
        # Re-initialise the LSTM hidden state for a batch of size bs.
        self.rnn.init_hidden(bs)

    def set_state(self, state):
        self.rnn.set_state(state)

    def get_state(self):
        return self.rnn.get_state()

    def recurrent(self):
        return True
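A minimal construction sketch for the Gaussian variant, assuming log_var_network may be any module that maps the flattened observation to a per-action log-variance (a plain nn.Linear here) and that the package's Normal wrapper exposes sample(); the dimensions and variable names are illustrative, not taken from the original.

# Illustrative only: hypothetical setup for a continuous-control task.
obs_dim, act_dim = 8, 2
log_var_net = nn.Linear(obs_dim, act_dim)            # state-dependent log-variance head
policy = GaussianLSTMPolicy(input_dim=obs_dim, hidden_dim=64, num_layers=1,
                            output_dim=act_dim, log_var_network=log_var_net)
policy.reset(bs=1)                                   # initialise the LSTM hidden state
obs = torch.zeros(1, obs_dim)                        # one time step, batch of one
dist = policy.forward(obs)                           # Normal with clamped log-variance
action = dist.sample()                               # continuous action of size act_dim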