Exemple #1
0
    def forward(self, env_state, torques):
        """Run the recurrent model over ``torques``, seeded from ``env_state``.

        The initial hidden state (and, when ``self.base_rnn == 'lstm'``, the
        initial cell state) is obtained by encoding ``env_state`` with its
        torque entries removed. The RNN output sequence is then passed through
        ``self.decoder``.

        Args:
            env_state: Environment state tensor. The original author's note
                gives its shape as (batch, seq, env_state).
                NOTE(review): the column slicing below only produces a
                (num_layers, batch, hidden_size) tensor when the encoder
                output is 2-D, i.e. (batch, num_layers * hidden_size) --
                confirm the real shape against the caller.
            torques: Input sequence handed to ``self.rnn``; noted by the
                original author as (batch, seq, torques).

        Returns:
            ``self.decoder`` applied to the RNN output sequence.
        """
        hidden_size = self.rnn.hidden_size

        def split_per_layer(flat_state):
            # The RNN expects its initial state stacked along dim 0 (one slab
            # per layer), so carve consecutive hidden_size-wide column blocks
            # out of the flat encoder output and stack them.
            per_layer = []
            for layer in range(self.rnn.num_layers):
                start = layer * hidden_size
                per_layer.append(flat_state[:, start:start + hidden_size])
            return torch.stack(per_layer, dim=0)

        h_0 = split_per_layer(self.encoder0(remove_torque(env_state)))

        if self.base_rnn == 'lstm':
            # An LSTM additionally needs an initial cell state, which comes
            # from its own encoder.
            c_0 = split_per_layer(self.encoder1(remove_torque(env_state)))
            output, (_h_n, _c_n) = self.rnn(torques, (h_0, c_0))
        else:
            # GRU / vanilla RNN: only a hidden state is required.
            output, _h_n = self.rnn(torques, h_0)

        return self.decoder(output)
Exemple #2
0
    def forward(self, x):
        """Pass ``x`` through the fully connected layers named in ``self.fc_list``.

        Torque entries are stripped from the input first. Every layer except
        the last is followed by ReLU; the last one is followed by tanh.

        Args:
            x: Input tensor accepted by ``remove_torque``.

        Returns:
            The tanh-activated output of the final layer.
        """
        hidden = remove_torque(x)

        # Hidden layers: looked up by attribute name, each followed by ReLU.
        for layer_name in self.fc_list[:-1]:
            hidden = F.relu(getattr(self, layer_name)(hidden))

        # Output layer uses tanh instead of ReLU.
        # todo: better activation function?
        output_layer = getattr(self, self.fc_list[-1])
        return torch.tanh(output_layer(hidden))
Exemple #3
0
 def forward(self, x0, x1):
     """Encode ``x0`` together with the arm-1 end-point change from ``x0`` to ``x1``.

     The network input is ``remove_torque(x0)`` concatenated (last dim) with
     ``get_arm1_end_points(x1) - get_arm1_end_points(x0)``. It is then run
     through ``self.encoder``, ``self.stepper`` and ``self.decoder`` in that
     order, and the result is squashed with tanh.

     Args:
         x0: First state tensor.
         x1: Second state tensor; only its arm-1 end points are used.

     Returns:
         The tanh-activated decoder output.
     """
     end_point_delta = get_arm1_end_points(x1) - get_arm1_end_points(x0)
     features = torch.cat([remove_torque(x0), end_point_delta], dim=-1)

     decoded = self.decoder(self.stepper(self.encoder(features)))
     return torch.tanh(decoded)
Exemple #4
0
    def forward(self, s0, s_target, torque_predictors):
        """Blend the outputs of several torque predictors with learned weights.

        A two-layer network (``fc0`` with ReLU, then ``fc1``) maps the
        torque-free ``s0`` concatenated with the arm-1 end points of
        ``s_target`` to a softmax distribution over the last dimension. Each
        predictor is called as ``pr(s0, s_target)``; the results are stacked on
        dim -2, scaled by the corresponding weight and summed over that dim.

        Args:
            s0: Current state tensor.
            s_target: Target state tensor.
            torque_predictors: Iterable of callables taking ``(s0, s_target)``.
                Presumably there is one predictor per ``fc1`` output unit so
                that the weights broadcast -- confirm against the caller.

        Returns:
            A tuple ``(torques, weights)``: the weighted sum of the predictor
            outputs and the softmax weights used to form it.
        """
        gate_input = torch.cat(
            [remove_torque(s0), get_arm1_end_points(s_target)], -1
        )
        weights = torch.softmax(self.fc1(F.relu(self.fc0(gate_input))), -1)

        # Stack the per-predictor outputs on a new second-to-last axis, then
        # collapse that axis with a weighted sum.
        candidates = torch.stack(
            [predictor(s0, s_target) for predictor in torque_predictors], -2
        )
        blended = (candidates * weights.unsqueeze(-1)).sum(-2)

        return blended, weights
Exemple #5
0
 def forward(self, x):
     """Apply the three-layer network to ``x`` with its torque entries removed.

     ``fc0`` and ``fc1`` are each followed by ReLU; ``fc2`` is followed by
     tanh.

     Args:
         x: Input tensor accepted by ``remove_torque``.

     Returns:
         The tanh-activated output of ``fc2``.
     """
     hidden = F.relu(self.fc0(remove_torque(x)))
     hidden = F.relu(self.fc1(hidden))
     # todo: better activation function?
     return torch.tanh(self.fc2(hidden))
Exemple #6
0
 def forward(self, x):
     """Transform the torque-free part of ``x`` with ``fc0`` and re-attach the torques.

     The torques are read from ``x`` before they are stripped, bypass
     ``self.fc0`` untouched, and are put back with ``add_torque``.

     Args:
         x: Input tensor accepted by ``get_torques`` and ``remove_torque``.

     Returns:
         The result of ``add_torque(self.fc0(remove_torque(x)), torques)``.
     """
     saved_torques = get_torques(x)
     transformed = self.fc0(remove_torque(x))
     return add_torque(transformed, saved_torques)