Exemple #1
0
    def forward(self, x):
        '''
        Evaluate the symmetrised amplitude <x|psi> for a batch of configurations.

        The input is expanded over the symmetry group (one transformed copy of
        every sample per matrix in ``self.Ts``), each copy is passed through the
        ``self.conv1`` layers (with ReLU) and ``self.conv1_output``, and the
        outputs are averaged over windows of ``self.group_elements`` values.
        The average is finally multiplied by a permutation sign.

        :param x: nsamplesx12 tensor, where the first (last) 6 components correspond to a single spin
        :return: sign * group-averaged network output, i.e. <x|psi>
        '''
        # Tensor.to is NOT in-place; the result has to be re-bound, otherwise
        # the input silently stays on its original device.
        x = x.to(self.device)
        n_group = self.group_elements

        # c3v+I slicing: row (s * n_group + i) holds sample s transformed by
        # the i-th group element -> shape (nsamples * n_group, 12).
        # The buffer is allocated on the same device as x so the assignment
        # below and the subsequent layers do not hit a device mismatch.
        inv = torch.zeros(size=(n_group * x.shape[0], x.shape[1]), device=x.device)
        for i, T in enumerate(self.Ts):
            inv[i::n_group, :] = x.mm(T.to(x.device))

        # Apply the first network to every transformed copy. Layers are called
        # directly (not via .forward) so that registered hooks are honoured.
        for layer in self.conv1:
            inv = torch.relu(layer(inv))
        inv = self.conv1_output(inv)

        # NOTE: a rolling step followed by a dense stack (self.dense /
        # self.dense_output) used to sit here and is currently disabled.

        # Pooling: average consecutive windows of n_group values, i.e. over
        # the group copies of each sample.
        pool = torch.nn.AvgPool1d(n_group)
        value = pool(inv.flatten().unsqueeze(0).unsqueeze(1)).squeeze()

        # We still have to define the sign before the value, because of the
        # anticommutation rules.
        # NOTE(review): the sign is derived from the FIRST sample only and is
        # applied to the whole batch -- confirm that callers only pass batches
        # for which this is valid (e.g. nsamples == 1).
        state = x[0].detach().cpu().numpy()
        hex_order = [0, 1, 3, 5, 4, 2, 6, 7, 9, 11, 10, 8]
        # Label every site by its 1-based position (0 where state is 0) ...
        order = state * np.arange(1, 13)
        # ... reorder the labels into the hex ordering ...
        state_in_hex = order[hex_order]
        # ... and keep only the non-zero labels.
        state_in_hex = state_in_hex[np.where(state_in_hex > 0.1)]
        # The permutation that sorts the remaining labels; its signature
        # (+1 / -1) is the sign applied to the amplitude.
        order = np.argsort(state_in_hex)
        perm = Permutation(order)
        sign = perm.signature()

        return sign * value