Пример #1
0
 def _setup(self, config):
     """Build the target matrix, model, optimizer and probe input from ``config``.

     Reads the config keys ``target_matrix``, ``seed``, ``lr``,
     ``n_steps_per_epoch`` and ``n_epochs_per_validation``.

     Raises:
         AssertionError: if the target matrix is not 2D or 3D, or if its
             first two dimensions differ.
     """
     self.target_matrix = torch.tensor(config['target_matrix'],
                                       dtype=torch.float)
     # Check the rank first: indexing shape[1] on a 0-D or 1-D tensor would
     # otherwise raise a bare IndexError before this message could be shown.
     assert self.target_matrix.dim() in [
         2, 3
     ], 'target matrix must be 2D if real or 3D if complex'
     assert self.target_matrix.shape[0] == self.target_matrix.shape[
         1], 'Only square matrices are supported'
     size = self.target_matrix.shape[0]
     # Seed before constructing the model so its construction is reproducible.
     torch.manual_seed(config['seed'])
     # Transposing the permutation product won't capture the FFT, since we'd
     # need permutations that interleave the first half and second half
     # (inverse of the permutation that separates the even and the odd).
     # However, using the permutation product with increasing size will work
     # since it can represent bit reversal, which is its own inverse.
     self.model = nn.Sequential(
         Block2x2DiagProduct(size=size, complex=True,
                             decreasing_size=False),
         BlockPermProduct(size=size,
                          complex=True,
                          share_logit=False,
                          increasing_size=True),
     )
     self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     self.n_epochs_per_validation = config['n_epochs_per_validation']
     # Identity matrix passed through the project's real_to_complex helper.
     self.input = real_to_complex(torch.eye(size))
Пример #2
0
 def _setup(self, config):
     """Seed torch, then build the butterfly model, optimizer and FFT target.

     Reads the config keys ``size``, ``seed``, ``lr`` and
     ``n_steps_per_epoch``.
     """
     n = config['size']
     # Seed before constructing the model so its construction is reproducible.
     torch.manual_seed(config['seed'])
     butterfly = ButterflyProduct(size=n, complex=True, fixed_order=True)
     self.model = butterfly
     self.optimizer = optim.Adam(butterfly.parameters(), lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     # Target: the transform applied to the identity, i.e. the transform's
     # own matrix.
     # NOTE(review): `torch.fft(x, 1)` is the legacy callable form - confirm
     # the pinned PyTorch version still provides it.
     identity = real_to_complex(torch.eye(n))
     self.target_matrix = torch.fft(identity, 1)
     bit_reversed = bitreversal_permutation(n)
     self.br_perm = torch.tensor(bit_reversed)
Пример #3
0
 def _setup(self, config):
     """Seed torch, then build a permutation-learning butterfly model.

     Reads the config keys ``seed``, ``size``, ``fixed_order``,
     ``softmax_fn``, ``lr`` and ``n_steps_per_epoch``. The key
     ``semantic_loss_weight`` is read (and the attribute of the same name
     set) only when the order is not fixed and ``softmax_fn`` equals
     ``'softmax'``.
     """
     torch.manual_seed(config['seed'])
     size = config['size']
     fixed_order = config['fixed_order']
     softmax_fn = config['softmax_fn']
     self.model = ButterflyProduct(
         size=size,
         complex=True,
         fixed_order=fixed_order,
         softmax_fn=softmax_fn,
         learn_perm=True,
     )
     uses_semantic_loss = not fixed_order and softmax_fn == 'softmax'
     if uses_semantic_loss:
         self.semantic_loss_weight = config['semantic_loss_weight']
     trainable = self.model.parameters()
     self.optimizer = optim.Adam(trainable, lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     # Target: the transform applied to the identity, i.e. the transform's
     # own matrix.
     identity = real_to_complex(torch.eye(size))
     self.target_matrix = torch.fft(identity, 1)
0
 def _setup(self, config):
     """Build the target matrix, model, optimizer and permuted input from ``config``.

     Reads the config keys ``target_matrix``, ``seed``, ``lr``,
     ``n_steps_per_epoch`` and ``n_epochs_per_validation``.

     Raises:
         AssertionError: if the target matrix is not 2D or 3D, or if its
             first two dimensions differ.
     """
     self.target_matrix = torch.tensor(config['target_matrix'],
                                       dtype=torch.float)
     # Check the rank first: indexing shape[1] on a 0-D or 1-D tensor would
     # otherwise raise a bare IndexError before this message could be shown.
     assert self.target_matrix.dim() in [
         2, 3
     ], 'target matrix must be 2D if real or 3D if complex'
     assert self.target_matrix.shape[0] == self.target_matrix.shape[
         1], 'Only square matrices are supported'
     size = self.target_matrix.shape[0]
     # Seed before constructing the model so its construction is reproducible.
     torch.manual_seed(config['seed'])
     self.model = Block2x2DiagProduct(size=size, complex=True)
     self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     self.n_epochs_per_validation = config['n_epochs_per_validation']
     # Identity matrix with its columns reordered by the bit-reversal
     # permutation, then passed through the project's real_to_complex helper.
     self.input = real_to_complex(
         torch.eye(size)[:, torch.tensor(bitreversal_permutation(size))])