Example #1
import torch
import torch.nn as nn

# GaussianNetwork, RNN, and the Normal distribution wrapper used below are
# project-local classes assumed to be importable from the surrounding codebase.


class GaussianBidirectionalNetwork(GaussianNetwork):
    def __init__(self, input_dim, hidden_dim, num_layers, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_layers = num_layers
        # self.output_dim = output_dim
        self.lstm = nn.LSTM(input_dim,
                            hidden_dim,
                            num_layers,
                            bidirectional=True)
        self.rnn = RNN(self.lstm, hidden_dim)
        self.modules.extend([self.rnn])
        # self.linear = nn.Linear(hidden_dim * 2, output_dim)

    def reset(self, x):
        self.rnn.init_hidden(x.size()[1])

    def forward(self, x):
        self.reset(x)
        lstm_out = self.rnn.forward(x)             # sequence output from the LSTM wrapper
        lstm_mean = torch.mean(lstm_out, dim=0)    # average over the time dimension
        # output = self.linear(lstm_mean)
        mean, log_var = self.mean_network(lstm_mean), self.log_var_network(
            lstm_mean)
        dist = Normal(mean, log_var=log_var)
        return dist

    def recurrent(self):
        return True
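The forward pass above averages the bidirectional LSTM output over the time dimension before handing it to the mean and log-variance heads inherited from GaussianNetwork. Below is a minimal shape check of that flow using plain PyTorch modules; the dimensions and the stand-in Linear heads are illustrative assumptions, not part of the class, and the point of interest is the hidden_dim * 2 feature size produced by the bidirectional LSTM:

import torch
import torch.nn as nn

seq_len, batch, input_dim, hidden_dim = 10, 4, 8, 16
lstm = nn.LSTM(input_dim, hidden_dim, num_layers=1, bidirectional=True)
mean_head = nn.Linear(hidden_dim * 2, 3)      # stand-in for the inherited mean_network
log_var_head = nn.Linear(hidden_dim * 2, 3)   # stand-in for the inherited log_var_network

x = torch.randn(seq_len, batch, input_dim)
out, _ = lstm(x)                              # (seq_len, batch, hidden_dim * 2)
pooled = out.mean(dim=0)                      # average over time, as in forward()
mean, log_var = mean_head(pooled), log_var_head(pooled)
print(mean.shape, log_var.shape)              # torch.Size([4, 3]) torch.Size([4, 3])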
Example #2
import torch.nn as nn

# ModuleContainer, RNN, and the Categorical distribution wrapper used below are
# project-local classes assumed to be importable from the surrounding codebase.


class LSTMPolicy(ModuleContainer):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.rnn = RNN(self.lstm, hidden_dim)
        self.softmax = nn.Softmax(dim=-1)  # explicit dim; bare nn.Softmax() is deprecated
        self.linear = nn.Linear(hidden_dim, output_dim)
        self.modules = [self.rnn, self.linear]

    def reset(self, bs):
        self.rnn.init_hidden(bs)

    def forward(self, x):
        if len(x.size()) == 2:
            x = x.unsqueeze(0)
        lstm_out = self.rnn.forward(x)
        lstm_reshape = lstm_out.view((-1, self.hidden_dim))
        output = self.softmax(self.linear(lstm_reshape))
        dist = Categorical(output)
        return dist

    def set_state(self, state):
        self.rnn.set_state(state)

    def get_state(self):
        return self.rnn.get_state()

    def recurrent(self):
        return True
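A hedged usage sketch for LSTMPolicy follows. The constructor arguments come straight from __init__ above; ModuleContainer, RNN, and the Categorical wrapper are assumed to be provided by the surrounding project, and the batch size, dimensions, and the sample() call are illustrative assumptions rather than documented API:

import torch

# Illustrative dimensions only; project-local classes are assumed importable.
policy = LSTMPolicy(input_dim=4, hidden_dim=64, num_layers=1, output_dim=2)
policy.reset(1)                      # (re)initialise the recurrent hidden state
obs = torch.randn(1, 4)              # (batch, input_dim); forward() adds a time axis
dist = policy.forward(obs)           # project-local Categorical over the 2 actions
action = dist.sample()               # assumes the wrapper exposes sample()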
Example #3
import numpy as np
import torch
import torch.nn as nn

# ModuleContainer, RNN, Normal, np_to_var, and xavier_init used below are
# project-local helpers assumed to be importable from the surrounding codebase.


class GaussianLSTMPolicy(ModuleContainer):
    def __init__(self,
                 input_dim,
                 hidden_dim,
                 num_layers,
                 output_dim,
                 log_var_network,
                 init=xavier_init,
                 scale_final=False,
                 min_var=1e-4,
                 obs_filter=None):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.rnn = RNN(self.lstm, hidden_dim)
        self.softmax = nn.Softmax(dim=-1)  # explicit dim; bare nn.Softmax() is deprecated
        self.linear = nn.Linear(hidden_dim, output_dim)

        self.log_var_network = log_var_network
        self.modules = [self.rnn, self.linear, self.log_var_network]

        self.obs_filter = obs_filter
        self.min_log_var = np_to_var(
            np.log(np.array([min_var])).astype(np.float32))

        self.apply(init)
        # self.apply(weights_init_mlp)
        if scale_final:
            # This class defines no `mean_network`; the final mean layer here is
            # `self.linear`, so scale its weights directly (the original attribute
            # lookup would raise AttributeError).
            self.linear.weight.data.mul_(0.01)

    def forward(self, x):
        if self.obs_filter is not None:
            x.data = self.obs_filter(x.data)
        if len(x.size()) == 2:
            x = x.unsqueeze(0)
        lstm_out = self.rnn.forward(x)
        lstm_reshape = lstm_out.view((-1, self.hidden_dim))
        mean = self.softmax(self.linear(lstm_reshape))

        log_var = self.log_var_network(x.contiguous().view((-1, x.shape[-1])))
        log_var = torch.max(self.min_log_var, log_var)
        # TODO Limit log var
        dist = Normal(mean=mean, log_var=log_var)
        return dist

    def reset(self, bs):
        self.rnn.init_hidden(bs)

    def set_state(self, state):
        self.rnn.set_state(state)

    def get_state(self):
        return self.rnn.get_state()

    def recurrent(self):
        return True
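A hedged construction sketch for GaussianLSTMPolicy. Everything project-local (ModuleContainer, RNN, Normal, np_to_var, xavier_init) is assumed to be importable; the log_var_head below is a hypothetical stand-in for whatever log-variance module a caller would actually pass in, and the dimensions are illustrative:

import torch
import torch.nn as nn

# Hypothetical log-variance head; it receives flattened observations of shape
# (-1, input_dim), matching the call in forward().
log_var_head = nn.Linear(4, 2)

policy = GaussianLSTMPolicy(input_dim=4, hidden_dim=64, num_layers=1,
                            output_dim=2, log_var_network=log_var_head)
policy.reset(1)                      # batch size 1
obs = torch.randn(1, 4)              # (batch, input_dim); forward() adds a time axis
dist = policy.forward(obs)           # project-local Normal with mean and log_var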