Example #1
def test_butterfly_fft():
    # DFT matrix for n = 4
    size = 4
    DFT = torch.fft(real_to_complex(torch.eye(size)), 1)
    P = real_to_complex(
        torch.tensor([[1., 0., 0., 0.], [0., 0., 1., 0.], [0., 1., 0., 0.],
                      [0., 0., 0., 1.]]))
    M0 = Butterfly(size,
                   diagonal=2,
                   complex=True,
                   diag=torch.tensor(
                       [[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]],
                       requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [1.0, 0.0]],
                                        requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, -1.0]],
                                          requires_grad=True))
    M1 = Butterfly(size,
                   diagonal=1,
                   complex=True,
                   diag=torch.tensor(
                       [[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]],
                       requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]],
                                        requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]],
                                          requires_grad=True))
    assert torch.allclose(
        complex_matmul(M0.matrix(), complex_matmul(M1.matrix(), P)), DFT)
    br_perm = torch.tensor(bitreversal_permutation(size))
    assert torch.allclose(
        complex_matmul(M0.matrix(), M1.matrix())[:, br_perm], DFT)
    D = complex_matmul(DFT, P.transpose(0, 1))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix()), D)
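
The identity this test checks can be reproduced with plain NumPy: for n = 4 the DFT matrix factors into two butterfly stages followed by a bit-reversal column permutation. A minimal, self-contained sketch (NumPy only, not the repo's Butterfly class; the twiddle w = exp(-2*pi*i/4) is written out explicitly):

import numpy as np

n = 4
DFT = np.fft.fft(np.eye(n), axis=0)   # DFT matrix: column k is the FFT of the k-th unit vector
w = np.exp(-2j * np.pi / n)           # twiddle factor
# First stage (stride 2): combines the two half-size DFTs with twiddles.
M0 = np.array([[1, 0, 1, 0],
               [0, 1, 0, w],
               [1, 0, -1, 0],
               [0, 1, 0, -w]])
# Second stage (stride 1): 2x2 DFT blocks.
M1 = np.array([[1, 1, 0, 0],
               [1, -1, 0, 0],
               [0, 0, 1, 1],
               [0, 0, 1, -1]])
br = [0, 2, 1, 3]                     # bit-reversal permutation for n = 4
assert np.allclose((M0 @ M1)[:, br], DFT)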
Example #2
def test_blockpermproduct():
    size = 8
    input = torch.randn(3, size, 2)
    perm = BlockPermProduct(size, complex=True, share_logit=True)
    perm.logit[0] = float('inf')
    from utils import bitreversal_permutation
    assert torch.allclose(perm(input), input[:, bitreversal_permutation(size)])
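
The test pins the shared permutation logit at +inf so the product acts as a fixed permutation and compares it against bit reversal. For reference, a minimal pure-Python sketch of what bitreversal_permutation is expected to return (the repo's implementation may differ in details):

import math

def bitreversal_permutation(n):
    """Indices 0..n-1 with their log2(n)-bit binary representations reversed."""
    m = int(math.log2(n))
    assert n == 1 << m, 'n must be a power of 2'
    return [int(format(i, f'0{m}b')[::-1], 2) for i in range(n)]

assert bitreversal_permutation(8) == [0, 4, 2, 6, 1, 5, 3, 7]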
Example #3
 def __init__(self, size, complex=False, decreasing_size=True):
     super().__init__()
     m = int(math.log2(size))
     assert size == 1 << m, "size must be a power of 2"
     self.size = size
     self.complex = complex
     sizes = ([size >> i for i in range(m)] if decreasing_size
              else [size >> i for i in range(m)[::-1]])
     self.factors = nn.ModuleList(
         [Block2x2DiagBmm(size_, complex=complex) for size_ in sizes])
     self.br_perm = bitreversal_permutation(size)
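
For size = 16 this builds log2(size) block-diagonal factors whose block sizes halve from size down to 2 (or grow from 2 up to size when decreasing_size is False); a quick check of the size list:

import math

size = 16
m = int(math.log2(size))
assert [size >> i for i in range(m)] == [16, 8, 4, 2]          # decreasing_size=True
assert [size >> i for i in range(m)[::-1]] == [2, 4, 8, 16]    # decreasing_size=False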
Example #4
def test_butterfly_dct():
    from scipy.fftpack import dct
    # DCT matrix for n = 4
    size = 4
    # Need to transpose as dct acts on rows of matrix np.eye, not columns
    DCT = torch.tensor(dct(np.eye(size)).T, dtype=torch.float)
    M0diag = torch.tensor([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]],
                          requires_grad=True)
    M0subdiag = torch.tensor([[1.0, 0.0], [1.0, 0.0]], requires_grad=True)
    M0superdiag = torch.tensor([[1.0, 0.0], [0.0, -1.0]], requires_grad=True)
    M0 = Butterfly(size,
                   diagonal=2,
                   complex=True,
                   diag=M0diag,
                   subdiag=M0subdiag,
                   superdiag=M0superdiag)
    M1 = Butterfly(size,
                   diagonal=1,
                   complex=True,
                   diag=torch.tensor(
                       [[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]],
                       requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]],
                                        requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]],
                                          requires_grad=True))
    arange_ = np.arange(size)
    dct_perm = np.concatenate((arange_[::2], arange_[::-2]))
    br_perm = bitreversal_permutation(size)
    perm = torch.arange(size)[dct_perm][br_perm]
    arange_ = torch.arange(size, dtype=torch.float)
    diag_real = 2 * torch.cos(-math.pi * arange_ / (2 * size))
    diag_imag = 2 * torch.sin(-math.pi * arange_ / (2 * size))
    diag = torch.stack((torch.diag(diag_real), torch.diag(diag_imag)), dim=-1)
    assert torch.allclose(
        complex_matmul(diag, complex_matmul(M0.matrix(), M1.matrix()))[:, perm, 0],
        DCT)
    D = torch.stack((diag_real, diag_imag), dim=-1)
    DM0 = Butterfly(size,
                    diagonal=2,
                    complex=True,
                    diag=complex_mul(D, M0diag),
                    subdiag=complex_mul(D[2:], M0subdiag),
                    superdiag=complex_mul(D[:2], M0superdiag))
    assert torch.allclose(
        complex_matmul(DM0.matrix(), M1.matrix())[:, perm, 0], DCT)
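
What the test encodes is the standard FFT-based DCT-II identity: reorder the input (even-indexed samples first, then the odd-indexed samples reversed), take an FFT, and multiply by the complex exponential diagonal 2*exp(-i*pi*k/(2n)), keeping the real part. A self-contained NumPy/SciPy check of that identity, independent of the repo's Butterfly code:

import numpy as np
from scipy.fftpack import dct

n = 8
x = np.random.randn(n)
v = np.concatenate((x[::2], x[::-2]))              # even samples, then reversed odd samples
k = np.arange(n)
X = 2 * np.real(np.exp(-1j * np.pi * k / (2 * n)) * np.fft.fft(v))
assert np.allclose(X, dct(x))                      # unnormalized DCT-II, as used above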
Example #5
 def __init__(self, input_size):
     super().__init__()
     filter_size = \
         max(64, int(2 ** (2 * torch.tensor(input_size)).float().log2().ceil()))
     self.butterfly_ifft = Butterfly(filter_size,
                                     filter_size,
                                     complex=True,
                                     tied_weight=False,
                                     bias=False)  # iFFT
     self.butterfly_ifft.twiddle = torch.nn.Parameter(
         fft_twiddle(filter_size, forward=False, normalized=True))
     self.butterfly_fft = Butterfly(filter_size,
                                    filter_size,
                                    complex=True,
                                    tied_weight=False,
                                    bias=False,
                                    increasing_stride=False)  # FFT
     self.butterfly_fft.twiddle = torch.nn.Parameter(
         butterfly_transpose_conjugate(self.butterfly_ifft.twiddle))
     f = fftfreq(filter_size)
     fourier_filter = self.create_filter(f)
     br = bitreversal_permutation(filter_size)
     self.fourier_filter_br = torch.nn.Parameter(fourier_filter[br])
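
The module sets up the classic frequency-domain filtering pattern, iFFT(filter * FFT(x)), with both transforms expressed as butterfly layers and the filter stored in bit-reversed order to match their output ordering. A minimal NumPy sketch of the operation being prepared, with a hypothetical ramp filter standing in for create_filter (which is not shown here):

import numpy as np

def fft_filter(x, filter_fn):
    """Apply a frequency-domain filter: iFFT(filter(freqs) * FFT(x))."""
    n = x.shape[-1]
    freqs = np.fft.fftfreq(n)
    return np.fft.ifft(filter_fn(freqs) * np.fft.fft(x, norm='ortho'), norm='ortho')

ramp = lambda f: 2 * np.abs(f)   # hypothetical filter, e.g. for filtered back-projection
y = fft_filter(np.random.randn(64), ramp)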
Example #6
def named_target_matrix(name, size):
    """
    Parameters:
        name: name of the target matrix
        size: dimension n of the target matrix
    Returns:
        target_matrix: (n, n) numpy array for real matrices or (n, n, 2) for complex matrices.
    """
    if name == 'dft':
        return LA.dft(size, scale='sqrtn')[:, :, None].view('float64')
    elif name == 'idft':
        return np.ascontiguousarray(LA.dft(size, scale='sqrtn').conj().T)[:, :, None].view('float64')
    elif name == 'dft2':
        size_sr = int(math.sqrt(size))
        matrix = np.fft.fft2(np.eye(size_sr**2).reshape(-1, size_sr, size_sr), norm='ortho').reshape(-1, size_sr**2)
        # matrix1d = LA.dft(size_sr, scale='sqrtn')
        # assert np.allclose(np.kron(m1d, m1d), matrix)
        # return matrix[:, :, None].view('float64')
        from butterfly.utils import bitreversal_permutation
        br_perm = bitreversal_permutation(size_sr)
        br_perm2 = np.arange(size_sr**2).reshape(size_sr, size_sr)[br_perm][:, br_perm].reshape(-1)
        matrix = np.ascontiguousarray(matrix[:, br_perm2])
        return matrix[:, :, None].view('float64')
    elif name == 'dct':
        # Need to transpose as dct acts on rows of matrix np.eye, not columns
        # return dct(np.eye(size), norm='ortho').T
        return dct(np.eye(size)).T / math.sqrt(size)
    elif name == 'dst':
        return dst(np.eye(size)).T / math.sqrt(size)
    elif name == 'hadamard':
        return LA.hadamard(size) / math.sqrt(size)
    elif name == 'hadamard2':
        size_sr = int(math.sqrt(size))
        matrix1d = LA.hadamard(size_sr) / math.sqrt(size_sr)
        return np.kron(matrix1d, matrix1d)
    elif name == 'b2':
        size_sr = int(math.sqrt(size))
        from butterfly import Block2x2DiagProduct
        b = Block2x2DiagProduct(size_sr)
        matrix1d = b(torch.eye(size_sr)).t().detach().numpy()
        return np.kron(matrix1d, matrix1d)
    elif name == 'convolution':
        np.random.seed(0)
        x = np.random.randn(size)
        return LA.circulant(x) / math.sqrt(size)
    elif name == 'hartley':
        return hartley_matrix(size) / math.sqrt(size)
    elif name == 'haar':
        return haar_matrix(size, normalized=True) / math.sqrt(size)
    elif name == 'legendre':
        grid = np.linspace(-1, 1, size + 2)[1:-1]
        return legendre.legvander(grid, size - 1).T / math.sqrt(size)
    elif name == 'hilbert':
        H = hilbert_matrix(size)
        return H / np.linalg.norm(H, 2)
    elif name == 'randn':
        np.random.seed(0)
        return np.random.randn(size, size) / math.sqrt(size)
    elif name == 'permutation':
        np.random.seed(0)
        perm = np.random.permutation(size)
        P = np.eye(size)[perm]
        return P
    elif name.startswith('rank-unnorm'):
        r = int(name[11:])
        np.random.seed(0)
        G = np.random.randn(size, r)
        H = np.random.randn(size, r)
        M = G @ H.T
        # M /= math.sqrt(size*r)
        return M
    elif name.startswith('rank'):
        r = int(name[4:])
        np.random.seed(0)
        G = np.random.randn(size, r)
        H = np.random.randn(size, r)
        M = G @ H.T
        M /= math.sqrt(size*r)
        return M
    elif name.startswith('sparse'):
        s = int(name[6:])
        # 2rn parameters
        np.random.seed(0)
        mask = sparse.random(size, size, density=s/size, data_rvs=np.ones)
        M = np.random.randn(size, size) * (mask.toarray())
        M /= math.sqrt(s)
        return M
    elif name.startswith('toeplitz'):
        r = int(name[8:])
        G = np.random.randn(size, r) / math.sqrt(size*r)
        H = np.random.randn(size, r) / math.sqrt(size*r)
        M = toeplitz_like(G, H)
        return M
    elif name == 'fastfood':
        n = size
        S = np.random.randn(n)
        G = np.random.randn(n)
        B = np.random.randn(n)
        # P = np.arange(n)
        P = np.random.permutation(n)
        H = hadamard(n)
        # SHGPHB
        # print(H)
        # print((H*B)[P,:])
        # print((H @ (G[:,np.newaxis] * (H * B)[P,:])))
        F = S[:,np.newaxis] * (H @ (G[:,np.newaxis] * (H * B)[P,:])) / n
        return F
        # x = np.random.randn(batch_size,n)
        # HB = hadamard_transform(B)
        # PHBx = HBx[:, P]
        # HGPHBx = hadamard_transform(G*PHBx)
        # return S*HGPHBx
    elif name == 'butterfly':
        # n (log n+1) params in the hierarchy
        b = Butterfly(in_size=size, out_size=size, bias=False, tied_weight=False, param='odo', nblocks=0)
        M = b(torch.eye(size))
        return M.cpu().detach().numpy()

    else:
        assert False, 'Target matrix name not recognized or implemented'
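
For complex targets such as 'dft' and 'idft', the (n, n) complex128 matrix is turned into the (n, n, 2) real layout promised by the docstring via a dtype view. A small standalone illustration of that trick (with scipy.linalg imported as LA, as the function assumes):

import numpy as np
import scipy.linalg as LA

n = 4
dft_complex = LA.dft(n, scale='sqrtn')                 # (n, n) complex128
dft_real = dft_complex[:, :, None].view('float64')     # (n, n, 2): last axis is (real, imag)
assert dft_real.shape == (n, n, 2)
assert np.allclose(dft_real[..., 0] + 1j * dft_real[..., 1], dft_complex)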
Example #7
    def _setup(self, config):
        device = config['device']
        self.device = device
        size = config['size']
        if isinstance(config['target_matrix'], str):
            self.target_matrix = torch.tensor(named_target_matrix(
                config['target_matrix'], size),
                                              dtype=torch.float).to(device)
        else:
            self.target_matrix = torch.tensor(config['target_matrix'],
                                              dtype=torch.float).to(device)
        assert self.target_matrix.shape[0] == self.target_matrix.shape[1], \
            'Only square matrices are supported'
        assert self.target_matrix.dim() in [2, 3], \
            'target matrix must be 2D if real or 3D if complex'
        complex = self.target_matrix.dim() == 3 or config['complex']
        torch.manual_seed(config['seed'])
        if config['model'] == 'B':
            self.model = nn.Sequential(
                FixedPermutation(torch.tensor(bitreversal_permutation(size))),
                Butterfly(in_size=size,
                          out_size=size,
                          bias=False,
                          complex=complex,
                          ortho_init=True)).to(device)
        elif config['model'] == 'BP':
            self.model = nn.Sequential(
                Permutation(size=size, share_logit=config['share_logit'][0]),
                Butterfly(in_size=size,
                          out_size=size,
                          bias=False,
                          complex=complex,
                          ortho_init=True)).to(device)
        elif config['model'] == 'PBT':
            self.model = nn.Sequential(
                Butterfly(in_size=size,
                          out_size=size,
                          bias=False,
                          complex=complex,
                          increasing_stride=False,
                          ortho_init=True),
                Permutation(size=size,
                            share_logit=config['share_logit'][0])).to(device)
        elif config['model'] == 'BPP':
            self.model = nn.Sequential(
                PermutationFactor(size=size),
                Permutation(size=size, share_logit=config['share_logit'][0]),
                Butterfly(in_size=size,
                          out_size=size,
                          bias=False,
                          complex=complex,
                          ortho_init=True)).to(device)
        elif config['model'] == 'BPBP':
            self.model = nn.Sequential(
                Permutation(size=size, share_logit=config['share_logit'][0]),
                Butterfly(in_size=size,
                          out_size=size,
                          bias=False,
                          complex=complex,
                          ortho_init=True),
                Permutation(size=size, share_logit=config['share_logit'][1]),
                Butterfly(in_size=size,
                          out_size=size,
                          bias=False,
                          complex=complex,
                          ortho_init=True)).to(device)
        elif config['model'] == 'BBT':
            # param_type = 'regular' if complex else 'perm'
            param_type = config['param']
            self.model = nn.Sequential(
                Butterfly(in_size=size,
                          out_size=size,
                          bias=False,
                          complex=complex,
                          param=param_type,
                          increasing_stride=False),
                Butterfly(in_size=size,
                          out_size=size,
                          bias=False,
                          complex=complex,
                          param=param_type,
                          increasing_stride=True))
        elif config['model'][0] == 'T' and (config['model'][1:]).isdigit():
            depth = int(config['model'][1:])
            param_type = config['param']
            self.model = nn.Sequential(*[
                Butterfly(in_size=size,
                          out_size=size,
                          bias=False,
                          complex=complex,
                          param=param_type,
                          increasing_stride=False) for _ in range(depth)
            ])
        elif config['model'][0:3] == 'BBT' and (config['model'][3:]).isdigit():
            depth = int(config['model'][3:])
            param_type = config['param']
            self.model = nn.Sequential(*[
                nn.Sequential(
                    Butterfly(in_size=size,
                              out_size=size,
                              bias=False,
                              complex=complex,
                              param=param_type,
                              increasing_stride=False),
                    Butterfly(in_size=size,
                              out_size=size,
                              bias=False,
                              complex=complex,
                              param=param_type,
                              increasing_stride=True)) for _ in range(depth)
            ])
        elif config['model'][0] == 'B' and (config['model'][1:]).isdigit():
            depth = int(config['model'][1:])
            param_type = config['param']
            self.model = nn.Sequential(*[
                Butterfly(in_size=size,
                          out_size=size,
                          bias=False,
                          complex=complex,
                          param=param_type,
                          increasing_stride=True) for _ in range(depth)
            ])
        elif config['model'] == 'butterfly':
            # e = int(config['model'][4:])
            self.model = Butterfly(in_size=size,
                                   out_size=size,
                                   complex=complex,
                                   **config['bfargs'])
        # elif config['model'][0:3] == 'ODO':
        #     if (config['model'][3:]).isdigit():
        #         width = int(config['model'][3:])
        #         self.model = Butterfly(in_size=size, out_size=size, bias=False, complex=False, param='odo', tied_weight=True, nblocks=0, expansion=width, diag_init='normal')
        #     elif config['model'][3] == 'k':
        #         k = int(config['model'][4:])
        #         self.model = Butterfly(in_size=size, out_size=size, bias=False, complex=False, param='odo', tied_weight=True, nblocks=k, diag_init='normal')

        # non-butterfly transforms
        # elif config['model'][0:2] == 'TL' and (config['model'][2:]).isdigit():
        #     rank = int(config['model'][2:])
        elif config['model'][0:4] == 'rank' and (
                config['model'][4:]).isdigit():
            rank = int(config['model'][4:])
            self.model = nn.Sequential(
                nn.Linear(size, rank, bias=False),
                nn.Linear(rank, size, bias=False),
            )

        else:
            assert False, f"Model {config['model']} not implemented"

        self.nparameters = sum(param.nelement()
                               for param in self.model.parameters())
        print("Parameters: ", self.nparameters)

        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        self.n_epochs_per_validation = config['n_epochs_per_validation']
        self.input = torch.eye(size).to(device)
        if complex:
            self.input = real_to_complex(self.input)
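
For orientation, a hypothetical config dict covering the fields this _setup reads; the keys follow the code above but the values are illustrative, not taken from the repo's experiment scripts:

config = {
    'device': 'cpu',
    'size': 64,
    'target_matrix': 'dft',          # or an explicit (n, n) / (n, n, 2) array
    'complex': False,                # forced to True when the target matrix is complex
    'seed': 0,
    'model': 'BP',                   # e.g. 'B', 'BP', 'PBT', 'BPP', 'BBT', 'T4', 'rank1', ...
    'share_logit': (True, True),
    'param': 'regular',              # only read by the BBT / T<k> / B<k> variants
    'bfargs': {},                    # only read by the 'butterfly' variant
    'lr': 1e-3,
    'n_steps_per_epoch': 50,
    'n_epochs_per_validation': 5,
}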
Example #8
def named_target_matrix(name, size):
    """
    Parameters:
        name: name of the target matrix
        size: dimension n of the target matrix
    Returns:
        target_matrix: (n, n) numpy array for real matrices or (n, n, 2) for complex matrices.
    """
    if name == 'dft':
        return LA.dft(size, scale='sqrtn')[:, :, None].view('float64')
    elif name == 'idft':
        return np.ascontiguousarray(LA.dft(
            size, scale='sqrtn').conj().T)[:, :, None].view('float64')
    elif name == 'dft2':
        size_sr = int(math.sqrt(size))
        matrix = np.fft.fft2(np.eye(size_sr**2).reshape(-1, size_sr, size_sr),
                             norm='ortho').reshape(-1, size_sr**2)
        # matrix1d = LA.dft(size_sr, scale='sqrtn')
        # assert np.allclose(np.kron(m1d, m1d), matrix)
        # return matrix[:, :, None].view('float64')
        from butterfly.utils import bitreversal_permutation
        br_perm = bitreversal_permutation(size_sr)
        br_perm2 = np.arange(size_sr**2).reshape(
            size_sr, size_sr)[br_perm][:, br_perm].reshape(-1)
        matrix = np.ascontiguousarray(matrix[:, br_perm2])
        return matrix[:, :, None].view('float64')
    elif name == 'dct':
        # Need to transpose as dct acts on rows of matrix np.eye, not columns
        # return dct(np.eye(size), norm='ortho').T
        return dct(np.eye(size)).T / math.sqrt(size)
    elif name == 'dst':
        return dst(np.eye(size)).T / math.sqrt(size)
    elif name == 'hadamard':
        return LA.hadamard(size) / math.sqrt(size)
    elif name == 'hadamard2':
        size_sr = int(math.sqrt(size))
        matrix1d = LA.hadamard(size_sr) / math.sqrt(size_sr)
        return np.kron(matrix1d, matrix1d)
    elif name == 'b2':
        size_sr = int(math.sqrt(size))
        import torch
        from butterfly import Block2x2DiagProduct
        b = Block2x2DiagProduct(size_sr)
        matrix1d = b(torch.eye(size_sr)).t().detach().numpy()
        return np.kron(matrix1d, matrix1d)
    elif name == 'convolution':
        np.random.seed(0)
        x = np.random.randn(size)
        return LA.circulant(x) / math.sqrt(size)
    elif name == 'hartley':
        return hartley_matrix(size) / math.sqrt(size)
    elif name == 'haar':
        return haar_matrix(size, normalized=True) / math.sqrt(size)
    elif name == 'legendre':
        grid = np.linspace(-1, 1, size + 2)[1:-1]
        return legendre.legvander(grid, size - 1).T / math.sqrt(size)
    elif name == 'hilbert':
        H = hilbert_matrix(size)
        return H / np.linalg.norm(H, 2)
    elif name == 'randn':
        np.random.seed(0)
        return np.random.randn(size, size) / math.sqrt(size)
    else:
        assert False, 'Target matrix name not recognized or implemented'
Example #9
 def _setup(self, config):
     device = config['device']
     self.device = device
     size = config['size']
     if isinstance(config['target_matrix'], str):
         self.target_matrix = torch.tensor(named_target_matrix(
             config['target_matrix'], size),
                                           dtype=torch.float).to(device)
     else:
         self.target_matrix = torch.tensor(config['target_matrix'],
                                           dtype=torch.float).to(device)
     assert self.target_matrix.shape[0] == self.target_matrix.shape[1], \
         'Only square matrices are supported'
     assert self.target_matrix.dim() in [2, 3], \
         'target matrix must be 2D if real or 3D if complex'
     complex = self.target_matrix.dim() == 3 or config['complex']
     torch.manual_seed(config['seed'])
     if config['model'] == 'B':
         self.model = nn.Sequential(
             FixedPermutation(torch.tensor(bitreversal_permutation(size))),
             Butterfly(in_size=size,
                       out_size=size,
                       bias=False,
                       complex=complex,
                       ortho_init=True)).to(device)
     elif config['model'] == 'BP':
         self.model = nn.Sequential(
             Permutation(size=size, share_logit=config['share_logit'][0]),
             Butterfly(in_size=size,
                       out_size=size,
                       bias=False,
                       complex=complex,
                       ortho_init=True)).to(device)
     elif config['model'] == 'PBT':
         self.model = nn.Sequential(
             Butterfly(in_size=size,
                       out_size=size,
                       bias=False,
                       complex=complex,
                       increasing_stride=False,
                       ortho_init=True),
             Permutation(size=size,
                         share_logit=config['share_logit'][0])).to(device)
     elif config['model'] == 'BPP':
         self.model = nn.Sequential(
             PermutationFactor(size=size),
             Permutation(size=size, share_logit=config['share_logit'][0]),
             Butterfly(in_size=size,
                       out_size=size,
                       bias=False,
                       complex=complex,
                       ortho_init=True)).to(device)
     elif config['model'] == 'BPBP':
         self.model = nn.Sequential(
             Permutation(size=size, share_logit=config['share_logit'][0]),
             Butterfly(in_size=size,
                       out_size=size,
                       bias=False,
                       complex=complex,
                       ortho_init=True),
             Permutation(size=size, share_logit=config['share_logit'][1]),
             Butterfly(in_size=size,
                       out_size=size,
                       bias=False,
                       complex=complex,
                       ortho_init=True)).to(device)
     else:
         assert False, f"Model {config['model']} not implemented"
     self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
     self.n_steps_per_epoch = config['n_steps_per_epoch']
     self.n_epochs_per_validation = config['n_epochs_per_validation']
     self.input = torch.eye(size).to(device)
     if complex:
         self.input = real_to_complex(self.input)