def test_from_rep(self, batch, tau0):
    """
    Check that ``SO3Tau.from_rep`` recovers the tau of a random representation.

    A random representation is built with one double-precision tensor per
    entry ``t`` of ``tau0`` (degree ``ell`` = its index), each of shape
    ``batch + (t, 2 * ell + 1, 2)``. The tau inferred from it must be exactly
    an ``SO3Tau`` whose entries equal ``tau0``.

    Parameters
    ----------
    batch : :obj:`tuple`
        Leading dimensions of every tensor (must be a tuple, since it is
        concatenated with a tuple below).
    tau0 : iterable of :obj:`int`
        Expected tau, one count per degree.
    """
    # One random tensor per degree ``ell``; ``t`` is the count for that degree.
    rep = [torch.rand(batch + (t, 2 * ell + 1, 2)).double()
           for ell, t in enumerate(tau0)]

    tau = SO3Tau.from_rep(rep)

    # Exact type check (as in the original), not isinstance.
    assert type(tau) is SO3Tau
    assert list(tau) == list(tau0)
def forward(self, rep):
    """
    Linearly mix a representation.

    Parameters
    ----------
    rep : :obj:`list` of :obj:`torch.Tensor`
        Representation to mix.

    Returns
    -------
    rep : :obj:`list` of :obj:`torch.Tensor`
        Mixed representation, produced by ``so3_torch.mix`` with
        ``self.weights``.

    Raises
    ------
    ValueError
        If the tau inferred from ``rep`` does not equal ``self.tau_in``.
    """
    # Infer the input tau once and reuse it for both the check and the
    # error message (previously computed twice on the error path).
    tau_rep = SO3Tau.from_rep(rep)
    if tau_rep != self.tau_in:
        raise ValueError('Tau of input rep does not match initialized tau!'
                         ' rep: {} tau: {}'.format(tau_rep, self.tau_in))
    return so3_torch.mix(self.weights, rep)