def _setup(self, config):
    self.target_matrix = torch.tensor(config['target_matrix'], dtype=torch.float)
    assert self.target_matrix.shape[0] == self.target_matrix.shape[1], 'Only square matrices are supported'
    assert self.target_matrix.dim() in [2, 3], 'target matrix must be 2D if real or 3D if complex'
    size = self.target_matrix.shape[0]
    torch.manual_seed(config['seed'])
    # Transposing the permutation product won't capture the FFT, since we'd get
    # permutations that interleave the first half and the second half (the inverse
    # of the permutation that separates the even and the odd indices).
    # However, using the permutation product with increasing size will work,
    # since it can represent bit reversal, which is its own inverse.
    self.model = nn.Sequential(
        Block2x2DiagProduct(size=size, complex=True, decreasing_size=False),
        BlockPermProduct(size=size, complex=True, share_logit=False, increasing_size=True),
    )
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    self.n_epochs_per_validation = config['n_epochs_per_validation']
    self.input = real_to_complex(torch.eye(size))
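
# Illustrative sanity check (a standalone sketch, not part of the trainer above):
# the radix-2 decimation-in-time factorization writes the DFT as a product of
# block-diagonal butterfly factors of *increasing* block size, times a bit-reversal
# permutation. Since bit reversal is its own inverse, a permutation product with
# increasing size can represent it, which is the reasoning in the comment above.
# The helpers below are hypothetical re-implementations for illustration only,
# not the repo's Block2x2DiagProduct / BlockPermProduct / bitreversal_permutation.
def _check_fft_butterfly_factorization(n=8):
    import numpy as np

    def butterfly_factor(size, block_size):
        # Block-diagonal DIT butterfly factor with blocks of size `block_size`.
        half = block_size // 2
        omega = np.exp(-2j * np.pi * np.arange(half) / block_size)
        block = np.block([[np.eye(half), np.diag(omega)],
                          [np.eye(half), -np.diag(omega)]])
        return np.kron(np.eye(size // block_size), block)

    def bit_reversal(m):
        log_m = m.bit_length() - 1
        return np.array([int(format(i, f'0{log_m}b')[::-1], 2) for i in range(m)])

    perm = bit_reversal(n)
    assert np.array_equal(perm[perm], np.arange(n))  # bit reversal is an involution
    product = np.eye(n, dtype=complex)
    block_size = 2
    while block_size <= n:  # increasing block sizes, cf. decreasing_size=False
        product = butterfly_factor(n, block_size) @ product
        block_size *= 2
    # Butterfly product composed with bit reversal reproduces the DFT matrix.
    assert np.allclose(product[:, perm], np.fft.fft(np.eye(n)))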
def _setup(self, config):
    size = config['size']
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=size, complex=True, fixed_order=True)
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
    self.br_perm = torch.tensor(bitreversal_permutation(size))
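
# Note (a sketch, for reference only): `torch.fft(input, 1)` above is the pre-1.8
# PyTorch API operating on the split-complex layout (trailing dimension of size 2
# for real/imag), so `self.target_matrix` is the size x size DFT matrix in that
# layout. On recent PyTorch the same target can be built with the `torch.fft`
# module; the helper below is illustrative and not used by the trainer.
def _dft_target_split_complex(size):
    """DFT matrix of order `size` in split-complex layout, shape (size, size, 2)."""
    dft = torch.fft.fft(torch.eye(size, dtype=torch.complex64))
    return torch.view_as_real(dft)
# `self.br_perm` holds the bit-reversal permutation; for size = 8 it is
# [0, 4, 2, 6, 1, 5, 3, 7], and applying it twice restores the original ordering.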
def _setup(self, config):
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=config['size'],
                                  complex=True,
                                  fixed_order=config['fixed_order'],
                                  softmax_fn=config['softmax_fn'],
                                  learn_perm=True)
    if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
        self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    size = config['size']
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
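
# Hypothetical illustration of how `semantic_loss_weight` could be consumed (a
# sketch only; the penalty below is NOT the repo's semantic loss, just a generic
# example of a regularizer that pushes softmax weights toward one-hot,
# permutation-like selections when the factor order is learned with a plain softmax).
def _example_semantic_penalty(logits):
    """Entropy of the softmax weights: 0 when one-hot, larger when diffuse."""
    probs = torch.softmax(logits, dim=-1)
    return -(probs * torch.log(probs + 1e-12)).sum(dim=-1).mean()
# A training step would then combine it with the reconstruction loss, e.g.
#     loss = reconstruction_loss + self.semantic_loss_weight * _example_semantic_penalty(logits)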
def _setup(self, config):
    self.target_matrix = torch.tensor(config['target_matrix'], dtype=torch.float)
    assert self.target_matrix.shape[0] == self.target_matrix.shape[1], 'Only square matrices are supported'
    assert self.target_matrix.dim() in [2, 3], 'target matrix must be 2D if real or 3D if complex'
    size = self.target_matrix.shape[0]
    torch.manual_seed(config['seed'])
    self.model = Block2x2DiagProduct(size=size, complex=True)
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    self.n_epochs_per_validation = config['n_epochs_per_validation']
    self.input = real_to_complex(
        torch.eye(size)[:, torch.tensor(bitreversal_permutation(size))])
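
# Note (a sketch): `real_to_complex` presumably lifts a real tensor into the
# split-complex layout by appending a zero imaginary part, along the lines of the
# hypothetical re-implementation below (for illustration; not necessarily the
# repo's exact helper). The input above is then the identity with bit-reversed
# columns, which lets the pure butterfly product, with no explicit permutation
# factor, account for the bit-reversal part of targets such as the FFT.
def _real_to_complex_sketch(x):
    """Stack a zero imaginary part onto `x`, giving shape (*x.shape, 2)."""
    return torch.stack((x, torch.zeros_like(x)), dim=-1)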