def forward_with_quaternion_mask(self, xs, hat_xs):
    """Masked multi-scale orientation-increment loss, quaternion version.

    Compares predicted orientation increments (built from ``hat_xs``) with
    ground-truth increments taken from ``xs``. The comparison is done at
    ``min_train_freq`` and at every coarser dyadic scale up to the
    ``max_N`` level. Windows that the mask channel marks as invalid are
    skipped.

    Args:
        xs: ground-truth tensor indexed as ``xs[:, :, c]``. Channels ``:3``
            are passed to ``SO3.qexp``; they are presumably rotation vectors
            of the increment over each ``min_train_freq`` window. Channel
            ``3`` is the validity mask (assumed to hold 0/1 values -- TODO
            confirm).
        hat_xs: predicted angular rates. They are reshaped to ``(-1, 3)`` and
            scaled by ``self.dt`` before being exponentiated.

    Returns:
        ``f_huber`` of the masked residuals at ``min_train_freq``, plus the
        coarser-scale terms weighted by 1/2, 1/4, ...
    """
    N = xs.shape[0]

    # Average the mask channel over each min_train_freq window. self.weight
    # is presumably a (1, 1, min_train_freq) kernel filled with
    # 1/min_train_freq. Result shape: (N, windows, 1).
    masks = torch.nn.functional.conv1d(
        xs[:, :, 3].unsqueeze(1), self.weight, bias=None,
        stride=self.min_train_freq).double().transpose(1, 2)
    # Keep a window only if all of its samples are valid (average of 1).
    # Compare with a half-sample tolerance instead of exactly 1: the float32
    # convolution may return 0.99999994 or 1.0000001 for an all-ones window,
    # and an exact `< 1` / `== 1` test would silently discard that window.
    # The result is an exact 0/1 mask, so the products below stay 0 or 1.
    masks = (masks > 1 - 0.5 / self.min_train_freq).double()

    Xs = SO3.qexp(xs[:, ::self.min_train_freq, :3].reshape(-1, 3).double())
    Omegas = SO3.qexp(self.dt * hat_xs.reshape(-1, 3).double())

    def masked_residuals(Omegas, Xs, masks):
        """Return log(inv(Omegas) * Xs) for valid windows past the first N0."""
        rs = SO3.qlog(SO3.qmul(SO3.qinv(Omegas), Xs))
        rs = rs.reshape(N, -1, 3)[:, self.N0:]
        return rs[masks[:, self.N0:].squeeze(2) == 1]

    # Compose adjacent pairs of per-sample increments min_N times. This lines
    # Omegas up with Xs, presumably because 2**min_N == min_train_freq
    # (TODO confirm).
    for _ in range(self.min_N):
        Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
    # NOTE(review): if no window survives the mask, the residual tensor is
    # empty. Whether f_huber handles that case (a mean over zero elements is
    # NaN) depends on its implementation -- verify.
    loss = self.f_huber(masked_residuals(Omegas, Xs, masks))

    # Coarser scales: halve the rate by composing adjacent increments. A
    # merged window is valid only if both of its halves were valid.
    for k in range(self.min_N, self.max_N):
        Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
        Xs = SO3.qmul(Xs[::2], Xs[1::2])
        masks = masks[:, ::2] * masks[:, 1::2]
        rs = masked_residuals(Omegas, Xs, masks)
        loss = loss + self.f_huber(rs) / (2**(k - self.min_N + 1))
    return loss
def forward_with_quaternions(self, xs, hat_xs):
    """Multi-scale orientation-increment loss, quaternion version.

    Compares predicted orientation increments (built from ``hat_xs``) with
    ground-truth increments taken from ``xs``. The comparison is done at
    ``min_train_freq`` and at every coarser dyadic scale up to the
    ``max_N`` level.

    Args:
        xs: ground-truth tensor. Every ``min_train_freq``-th entry along
            dim 1 is reshaped to ``(-1, 3)`` and passed to ``SO3.qexp``; the
            entries are presumably rotation vectors of the increment over
            each window, with a last dim of 3 (TODO confirm).
        hat_xs: predicted angular rates. They are reshaped to ``(-1, 3)`` and
            scaled by ``self.dt`` before being exponentiated.

    Returns:
        ``f_huber`` of the residuals at ``min_train_freq``, plus the
        coarser-scale terms weighted by 1/2, 1/4, ...
    """
    N = xs.shape[0]
    Xs = SO3.qexp(xs[:, ::self.min_train_freq].reshape(-1, 3).double())
    Omegas = SO3.qexp(self.dt * hat_xs.reshape(-1, 3).double())

    def residuals(Omegas, Xs):
        """Return log(inv(Omegas) * Xs) per sequence, dropping the first N0."""
        rs = SO3.qlog(SO3.qmul(SO3.qinv(Omegas), Xs))
        # Use reshape (not view): it works even if qlog returns a
        # non-contiguous tensor, and it matches the first-scale computation.
        return rs.reshape(N, -1, 3)[:, self.N0:]

    # Compose adjacent pairs of per-sample increments min_N times. This lines
    # Omegas up with Xs, presumably because 2**min_N == min_train_freq
    # (TODO confirm).
    for _ in range(self.min_N):
        Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
    loss = self.f_huber(residuals(Omegas, Xs))

    # Coarser scales: halve the rate by composing adjacent increments on
    # both the prediction and the ground truth.
    for k in range(self.min_N, self.max_N):
        Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
        Xs = SO3.qmul(Xs[::2], Xs[1::2])
        loss = loss + self.f_huber(residuals(Omegas, Xs)) / (2**(k - self.min_N + 1))
    return loss
def integrate_with_quaternions_superfast(self, N, raw_us, net_us):
    """Integrate raw and corrected angular rates into orientation quaternions.

    Steps:
    1. Each rate sample becomes a normalized increment ``qexp(omega * dt)``.
    2. Element 0 of both sequences is overwritten with the normalized first
       ground-truth quaternion ``self.gt['qs'][0]``.
    3. A log-depth prefix-product scan composes the sequences, so entry j
       ends up as the normalized product of entries 0..j.

    Args:
        N: not read. The scan length comes from ``raw_us``.
        raw_us: raw rates. Only columns ``:3`` are used.
        net_us: corrected rates. Only columns ``:3`` are used; assumed to
            have as many rows as ``raw_us``.

    Returns:
        A tuple ``(net_qs moved to CPU, imu rotation matrices as float32,
        net rotation matrices as float32)``.
    """
    dt = self.dt
    imu_qs = SO3.qnorm(SO3.qexp(raw_us[:, :3].cuda().double() * dt))
    net_qs = SO3.qnorm(SO3.qexp(net_us[:, :3].cuda().double() * dt))

    start = SO3.qnorm(self.gt['qs'][:2].cuda().double())[0]
    imu_qs[0] = start
    net_qs[0] = start

    # Scan with doubling offsets 1, 2, 4, ... while the offset is smaller
    # than the sequence length. After each pass, entry j holds the product
    # of the 2*step entries ending at j (or of all entries from 0 to j).
    length = imu_qs.shape[0]
    step = 1
    while step < length:
        imu_qs[step:] = SO3.qnorm(SO3.qmul(imu_qs[:-step], imu_qs[step:]))
        net_qs[step:] = SO3.qnorm(SO3.qmul(net_qs[:-step], net_qs[step:]))
        step *= 2

    imu_rotations = SO3.from_quaternion(imu_qs).float()
    net_rotations = SO3.from_quaternion(net_qs).float()
    return net_qs.cpu(), imu_rotations, net_rotations