def forward(self, x):
    '''
    Implement the symmetries of the state

    The input is expanded with every transformation matrix in ``self.Ts``,
    passed through the ``self.conv1`` layers (ReLU after each one) and
    ``self.conv1_output``, averaged over consecutive groups of
    ``self.group_elements`` entries, and finally multiplied by a permutation
    sign that accounts for the anticommutation rules.

    :param x: nsamplesx12 tensor, where the first (last) 6 components
        correspond to a single spin
    :return: <x|psi>
    '''
    # Tensor.to() is NOT in-place: it returns the converted tensor. The result
    # must be rebound, otherwise x silently stays on its original device.
    x = x.to(self.device)

    # We start by implementing the c3v+I slicing
    # output -> inv of shape (nsamples * group_elements) x 12.
    # Allocate on the same device as x so the slice assignment below and the
    # layers that follow do not hit a device mismatch.
    inv = torch.zeros(
        size=(self.group_elements * x.shape[0], x.shape[1]),
        device=x.device,
    )
    # Row n * group_elements + i holds sample n transformed by the i-th matrix.
    for i, T in enumerate(self.Ts):
        inv[i::self.group_elements, :] = x.mm(T)

    # We then apply the first convolutional network to each element.
    # Modules are called directly (not via .forward) so registered hooks run.
    for layer in self.conv1:
        inv = torch.relu(layer(inv))
    inv = self.conv1_output(inv)

    # Then we apply the rolling step (currently disabled)
    # index = np.asarray([np.roll(np.arange(self.group_elements) + i * self.group_elements, j)
    #                     for i in np.arange(x.shape[0]) for j in np.arange(self.group_elements)])
    # inv2 = inv[index, :].flatten(1)
    # Now we can perform the dense operations (currently disabled)
    # for l in self.dense:
    #     inv2 = torch.relu(l.forward(inv2))
    # inv = self.dense_output.forward(inv2)

    # Finally, we perform the pooling: average every run of group_elements
    # consecutive entries of the flattened output (stride defaults to the
    # kernel size, so the windows do not overlap).
    value = torch.nn.functional.avg_pool1d(
        inv.flatten().unsqueeze(0).unsqueeze(1),
        self.group_elements,
    ).squeeze()

    # We still have to define the sign before the value, because of the
    # anticommutation rules.
    # NOTE(review): only the first row of x is used here, so the same sign is
    # applied to every sample of a batch -- confirm that callers either pass a
    # single sample or that all samples in a batch share the same sign.
    state = x.cpu().detach().numpy()[0]
    hex_order = [0, 1, 3, 5, 4, 2, 6, 7, 9, 11, 10, 8]
    # keep track of the current order: occupied entries become their 1-based
    # position, empty entries stay 0
    order = state * np.arange(1, 13)
    # then, put it in the hex order
    state_in_hex = order[hex_order]
    # get the non-0 elements
    state_in_hex = state_in_hex[state_in_hex > 0.1]
    # get the ordered indices
    order = np.argsort(state_in_hex)
    # create a permutation
    perm = Permutation(order)
    # compute the sign
    sign = perm.signature()
    return sign * value