def forward_with_quaternion_mask(self, xs, hat_xs):
    """Masked multi-scale orientation-increment loss using quaternions.

    ``xs[:, :, :3]`` holds the reference data and ``xs[:, :, 3]`` a mask
    channel. The mask channel is convolved with ``self.weight`` using stride
    ``self.min_train_freq`` (presumably a box average over each window —
    TODO confirm against ``self.weight``). Values below 1 are zeroed, and
    only residuals whose window mask equals exactly 1 are passed to
    ``self.f_huber``. At every coarser scale, adjacent window masks are
    multiplied. Because the surviving values are 0 or >= 1, a coarse window
    equals 1 only if both halves equal 1.

    NOTE(review): if every window at some scale is masked out, an empty
    residual tensor is passed to ``self.f_huber``. What that returns depends
    on ``f_huber`` (not visible here); confirm it cannot yield NaN.

    Args:
        xs: reference data. Dim 0 is the batch, dim 1 is subsampled every
            ``self.min_train_freq`` samples, and channels 0..2 are
            exponentiated with ``SO3.qexp``. Channel 3 is the mask.
        hat_xs: estimated rates. They are reshaped to ``(-1, 3)``, scaled by
            ``self.dt`` and exponentiated with ``SO3.qexp``.

    Returns:
        The first-scale ``self.f_huber`` term plus each coarser term. The
        coarser term at level ``j`` (1-based) is divided by ``2 ** j``.
    """
    batch = xs.shape[0]
    stride = self.min_train_freq
    skip = self.N0

    def residuals(est, ref):
        # Relative rotation qlog(est^-1 * ref), split back per sequence,
        # with the first `skip` increments of every sequence dropped.
        err = SO3.qlog(SO3.qmul(SO3.qinv(est), ref))
        return err.reshape(batch, -1, 3)[:, skip:]

    def keep(window_mask):
        # Boolean selector of windows whose mask value is exactly 1.
        return window_mask[:, skip:].squeeze(2) == 1

    # Per-window mask, shaped (batch, windows, out_channels).
    window_mask = torch.nn.functional.conv1d(
        xs[:, :, 3].unsqueeze(1),
        self.weight,
        bias=None,
        stride=stride,
    ).double().transpose(1, 2)
    window_mask = window_mask.masked_fill(window_mask < 1, 0)

    ref = SO3.qexp(xs[:, ::stride, :3].reshape(-1, 3).double())
    scaled = self.dt * hat_xs.reshape(-1, 3).double()
    est = SO3.qexp(scaled[:, :3])

    # Compose the fine-grained estimates pairwise, min_N times, down to the
    # min_train_freq resolution.
    for _ in range(self.min_N):
        est = SO3.qmul(est[::2], est[1::2])

    loss = self.f_huber(residuals(est, ref)[keep(window_mask)])

    # Coarser scales: halve the resolution of estimate, reference and mask
    # together; weight each extra level by 1 / 2**level.
    for level in range(1, self.max_N - self.min_N + 1):
        est = SO3.qmul(est[::2], est[1::2])
        ref = SO3.qmul(ref[::2], ref[1::2])
        window_mask = window_mask[:, ::2] * window_mask[:, 1::2]
        term = self.f_huber(residuals(est, ref)[keep(window_mask)])
        loss = loss + term / (2 ** level)
    return loss
def forward_with_quaternions(self, xs, hat_xs):
    """Forward errors with quaternion: multi-scale orientation-increment loss.

    The estimated rates ``hat_xs`` are reshaped to ``(-1, 3)``, scaled by
    ``self.dt`` and mapped to quaternions with ``SO3.qexp``. They are then
    composed pairwise (``SO3.qmul``) ``self.min_N`` times. The reference
    ``xs`` is subsampled every ``self.min_train_freq`` samples along dim 1
    and exponentiated the same way. The count of the two quaternion
    sequences must match; this presumably relies on
    ``min_train_freq == 2 ** min_N`` (TODO confirm in the base class).

    The residual ``qlog(qinv(Omegas) * Xs)`` is reshaped to ``(N, -1, 3)``
    and the first ``self.N0`` entries along dim 1 are dropped. The result is
    penalised with ``self.f_huber``. Both sequences are then repeatedly
    composed pairwise for ``k`` in ``range(self.min_N, self.max_N)``. Each
    coarser residual is added with weight ``1 / 2 ** (k - self.min_N + 1)``.

    Args:
        xs: reference data. Dim 0 is the batch size ``N``; dim 1 is
            subsampled and the result reshaped to ``(-1, 3)`` (presumably
            ``(N, T, 3)`` — TODO confirm).
        hat_xs: estimated rates. They are reshaped to ``(-1, 3)`` and scaled
            by ``self.dt`` before exponentiation.

    Returns:
        The sum of the weighted ``self.f_huber`` terms over all scales.
    """
    N = xs.shape[0]
    Xs = SO3.qexp(xs[:, ::self.min_train_freq].reshape(-1, 3).double())
    hat_xs = self.dt * hat_xs.reshape(-1, 3).double()
    Omegas = SO3.qexp(hat_xs[:, :3])
    # compute increment at min_train_freq by decimation
    for _ in range(self.min_N):
        Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
    rs = SO3.qlog(SO3.qmul(SO3.qinv(Omegas), Xs))
    rs = rs.reshape(N, -1, 3)[:, self.N0:]
    loss = self.f_huber(rs)
    # compute increment from min_train_freq to max_train_freq
    for k in range(self.min_N, self.max_N):
        Omegas = SO3.qmul(Omegas[::2], Omegas[1::2])
        Xs = SO3.qmul(Xs[::2], Xs[1::2])
        rs = SO3.qlog(SO3.qmul(SO3.qinv(Omegas), Xs))
        # reshape rather than view: view raises on a non-contiguous tensor,
        # and qlog's output layout is not guaranteed here. This matches the
        # first-scale residual above.
        rs = rs.reshape(N, -1, 3)[:, self.N0:]
        loss = loss + self.f_huber(rs) / (2**(k - self.min_N + 1))
    return loss