Пример #1
0
 def forward_with_rotation_matrices_mask(self, xs, hat_xs):
     """Masked multi-scale rotation loss on component 2 of the residual.

     Ground-truth rotations are ``SO3.exp`` of ``xs[..., :3]`` taken every
     ``self.min_train_freq`` steps. Estimated rotations are
     ``SO3.exp(self.dt * hat_xs)``, composed pairwise ``self.min_N`` times
     (presumably ``self.min_train_freq == 2**self.min_N`` so both sequences
     line up -- TODO confirm). Each further level, for ``k`` from
     ``self.min_N`` to ``self.max_N - 1``, composes consecutive pairs again
     and adds its loss weighted by ``1 / 2**(k - self.min_N + 1)``.

     Channel 3 of ``xs`` is used as a validity mask. It is pooled by a strided
     ``conv1d`` with ``self.weight``. At every level, including the finest,
     only windows whose mask equals 1 contribute. Of each surviving residual,
     only component 2 is passed to ``self.f_huber``.

     Args:
         xs: ground-truth tensor indexed as ``xs[b, t, c]`` (presumably
             batch, time, channel); channels 0-2 feed ``SO3.exp`` and
             channel 3 is the mask.
         hat_xs: network output, reshaped to ``(-1, 3)``, cast to double and
             scaled by ``self.dt`` before ``SO3.exp``.

     Returns:
         Sum of the weighted ``self.f_huber`` terms over all levels.
     """
     N = xs.shape[0]

     def masked_residuals(Omegas, Xs, masks):
         # Log map of bmtm(Omegas, Xs) per batch element, first self.N0 steps
         # dropped, kept only where the mask is exactly 1; component 2 only.
         rs = SO3.log(bmtm(Omegas, Xs)).reshape(N, -1, 3)[:, self.N0:]
         rs = rs[masks[:, self.N0:].squeeze(2) == 1]
         return rs[:, 2]

     # Pool the validity channel over non-overlapping windows of
     # min_train_freq samples -> shape (N, windows, 1) after the transpose.
     masks = xs[:, :, 3].unsqueeze(1)
     masks = torch.nn.functional.conv1d(
         masks, self.weight, bias=None,
         stride=self.min_train_freq).double().transpose(1, 2)
     # Presumably self.weight averages the window, so a pooled value below 1
     # means at least one invalid sample in it -- TODO confirm in __init__.
     masks[masks < 1] = 0
     Xs = SO3.exp(xs[:, ::self.min_train_freq, :3].reshape(-1, 3).double())
     hat_xs = self.dt * hat_xs.reshape(-1, 3).double()
     Omegas = SO3.exp(hat_xs)
     # compute increment at min_train_freq by decimation
     for _ in range(self.min_N):
         Omegas = Omegas[::2].bmm(Omegas[1::2])
     # The finest level is masked and restricted to component 2 exactly like
     # the coarser levels, so invalid windows never contribute to the loss.
     # NOTE(review): if no window survives the mask, f_huber gets an empty
     # tensor; with a mean-reduced loss that evaluates to NaN -- confirm.
     loss = self.f_huber(masked_residuals(Omegas, Xs, masks))
     # compute increment from min_train_freq to max_train_freq
     for k in range(self.min_N, self.max_N):
         Omegas = Omegas[::2].bmm(Omegas[1::2])
         Xs = Xs[::2].bmm(Xs[1::2])
         # Mask values are 0 or >= 1, so the product equals 1 only when
         # both halves of the merged window equal 1.
         masks = masks[:, ::2] * masks[:, 1::2]
         loss = loss + self.f_huber(
             masked_residuals(Omegas, Xs, masks)) / (2**(k - self.min_N + 1))
     return loss
Пример #2
0
 def forward_with_rotation_matrices(self, xs, hat_xs):
     """Multi-scale rotation loss between integrated and reference rotations.

     The reference rotations are ``SO3.exp`` of ``xs`` sampled every
     ``self.min_train_freq`` steps. The estimate is
     ``SO3.exp(self.dt * hat_xs)``, folded pairwise ``self.min_N`` times
     (presumably ``self.min_train_freq == 2**self.min_N`` -- TODO confirm).
     The finest level contributes ``self.f_huber`` of the full 3-vector
     residual with weight 1. Each of the ``self.max_N - self.min_N`` coarser
     levels folds both sequences pairwise once more and contributes with
     weight ``1 / 2**level``.

     Args:
         xs: reference tensor; ``xs.shape[0]`` is the batch count and it is
             strided along axis 1 by ``self.min_train_freq``.
         hat_xs: network output, reshaped to ``(-1, 3)``, cast to double and
             scaled by ``self.dt``.

     Returns:
         Sum of the weighted ``self.f_huber`` terms.
     """
     batch = xs.shape[0]

     def residual(estimate, reference):
         # log map of bmtm(estimate, reference), regrouped per batch element,
         # with the first self.N0 steps discarded
         relative = SO3.log(bmtm(estimate, reference))
         return relative.reshape(batch, -1, 3)[:, self.N0:]

     reference = SO3.exp(xs[:, ::self.min_train_freq].reshape(-1, 3).double())
     increments = (self.dt * hat_xs.reshape(-1, 3).double())[:, :3]
     estimate = SO3.exp(increments)

     # fold the estimate down to the reference sampling rate
     for _ in range(self.min_N):
         estimate = estimate[::2].bmm(estimate[1::2])
     loss = self.f_huber(residual(estimate, reference))

     # coarser levels: fold both sequences again, halve the weight each time
     for level in range(1, self.max_N - self.min_N + 1):
         estimate = estimate[::2].bmm(estimate[1::2])
         reference = reference[::2].bmm(reference[1::2])
         loss = loss + self.f_huber(residual(estimate, reference)) / (2**level)
     return loss