Example #1
0
 def forward_with_quaternion_mask(self, xs, hat_xs):
     """Masked multi-scale Huber loss on quaternion rotation residuals.

     Channel 3 of ``xs`` is used as a per-sample mask and channels 0-2 as
     the reference increments; ``hat_xs`` holds the predictions, which are
     scaled by ``self.dt`` before being mapped through ``SO3.qexp``.

     NOTE(review): ``xs`` is presumably (batch, time, >=4) and ``hat_xs``
     (batch, time, 3) -- confirm against the caller.

     Returns:
         The sum of ``self.f_huber`` over the kept residuals at every
         scale; the i-th coarser scale is weighted by ``1 / 2**i``.
     """
     batch = xs.shape[0]
     stride = self.min_train_freq
     skip = self.N0

     def masked_residual(pred_q, ref_q, valid):
         # log-map of pred^-1 * ref, regrouped per batch element, with the
         # first `skip` steps dropped and only steps whose mask is 1 kept
         err = SO3.qlog(SO3.qmul(SO3.qinv(pred_q), ref_q))
         err = err.reshape(batch, -1, 3)[:, skip:]
         keep = valid[:, skip:].squeeze(2) == 1
         return err[keep]

     # pool the mask channel down to the min_train_freq rate
     valid = torch.nn.functional.conv1d(
         xs[:, :, 3].unsqueeze(1),
         self.weight,
         bias=None,
         stride=stride,
     )
     valid = valid.double().transpose(1, 2)
     # any pooled value below 1 counts as invalid
     valid[valid < 1] = 0

     ref_q = SO3.qexp(xs[:, ::stride, :3].reshape(-1, 3).double())
     scaled = self.dt * hat_xs.reshape(-1, 3).double()
     pred_q = SO3.qexp(scaled[:, :3])

     # bring the predicted increments to min_train_freq by composing
     # adjacent pairs, halving the sequence at every pass
     for _ in range(self.min_N):
         pred_q = SO3.qmul(pred_q[::2], pred_q[1::2])

     loss = self.f_huber(masked_residual(pred_q, ref_q, valid))

     # keep halving (predictions, references and masks together) from
     # min_train_freq up to max_train_freq, adding a down-weighted term
     for k in range(self.min_N, self.max_N):
         pred_q = SO3.qmul(pred_q[::2], pred_q[1::2])
         ref_q = SO3.qmul(ref_q[::2], ref_q[1::2])
         valid = valid[:, ::2] * valid[:, 1::2]
         scale = 2**(k - self.min_N + 1)
         term = self.f_huber(masked_residual(pred_q, ref_q, valid))
         loss = loss + term / scale
     return loss
Example #2
0
 def forward_with_quaternions(self, xs, hat_xs):
     """Multi-scale Huber loss on quaternion rotation residuals.

     ``xs`` is subsampled every ``self.min_train_freq`` steps and mapped
     through ``SO3.qexp`` to give the reference increments; ``hat_xs`` is
     scaled by ``self.dt`` and mapped the same way to give the predicted
     increments.

     NOTE(review): both inputs are flattened with ``reshape(-1, 3)``, so
     they presumably have a trailing dimension of 3 -- confirm against
     the caller.

     Returns:
         ``self.f_huber`` of the residual at ``min_train_freq`` plus, for
         each further halving up to ``self.max_N``, the same quantity
         weighted by ``1 / 2**i`` (i = 1, 2, ...).
     """
     N = xs.shape[0]
     Xs = SO3.qexp(xs[:, ::self.min_train_freq].reshape(-1, 3).double())
     hat_xs = self.dt * hat_xs.reshape(-1, 3).double()
     Omegas = SO3.qexp(hat_xs[:, :3])
     # compute increment at min_train_freq by decimation: compose adjacent
     # pairs, halving the sequence at every pass
     for _ in range(self.min_N):
         Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
     # residual = log(Omegas^-1 * Xs), regrouped per batch element; the
     # first self.N0 steps of each sequence are excluded from the loss
     rs = SO3.qlog(SO3.qmul(SO3.qinv(Omegas), Xs))
     rs = rs.reshape(N, -1, 3)[:, self.N0:]
     loss = self.f_huber(rs)
     # compute increment from min_train_freq to max_train_freq
     for k in range(self.min_N, self.max_N):
         Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
         Xs = SO3.qmul(Xs[::2], Xs[1::2])
         rs = SO3.qlog(SO3.qmul(SO3.qinv(Omegas), Xs))
         # reshape rather than view: consistent with the first residual
         # above and does not require qlog's output to be contiguous
         rs = rs.reshape(N, -1, 3)[:, self.N0:]
         # each coarser scale contributes half the weight of the previous
         loss = loss + self.f_huber(rs) / (2**(k - self.min_N + 1))
     return loss