예제 #1
0
def fastfood(diag1: torch.Tensor,
             diag2: torch.Tensor,
             diag3: torch.Tensor,
             permutation: torch.Tensor,
             normalized: bool = False,
             increasing_stride: bool = True,
             separate_diagonal: bool = True) -> nn.Module:
    """ Build a Butterfly-based nn.Module that performs the Fastfood multiplication:
    x -> Diag3 @ H @ Diag2 @ P @ H @ Diag1,
    with H the Hadamard matrix and P a permutation matrix.
    Parameters:
        diag1: (n,), where n is a power of 2.
        diag2: (n,)
        diag3: (n,)
        permutation: (n,)
        normalized: if True, use the orthogonal Hadamard transform
                    (i.e. multiplied by 1/sqrt(n))
        increasing_stride: whether the first Butterfly in the sequence has increasing stride;
                           the second Butterfly is built with the opposite setting.
        separate_diagonal: if True, the diagonals are kept as standalone Diagonal modules;
                           if False, they are combined into the Butterfly parts.
    Returns:
        An nn.Sequential implementing the map above.
    """
    (n,) = diag1.shape
    assert diag2.shape == diag3.shape == permutation.shape == (n, )
    # The two Hadamard factors use opposite stride orders.
    first_hadamard = hadamard(n, normalized, increasing_stride)
    second_hadamard = hadamard(n, normalized, not increasing_stride)

    if separate_diagonal:
        # Keep every diagonal as its own stage, in application order.
        stages = [
            Diagonal(diagonal_init=diag1),
            first_hadamard,
            FixedPermutation(permutation),
            Diagonal(diagonal_init=diag2),
            second_hadamard,
            Diagonal(diagonal_init=diag3),
        ]
        return nn.Sequential(*stages)

    # Fold Diag1 in front of the first Hadamard, and Diag2 / Diag3 around the second one.
    first_hadamard = diagonal_butterfly(first_hadamard, diag1, diag_first=True)
    second_hadamard = diagonal_butterfly(second_hadamard, diag2, diag_first=True)
    second_hadamard = diagonal_butterfly(second_hadamard, diag3, diag_first=False)
    return nn.Sequential(first_hadamard, FixedPermutation(permutation),
                         second_hadamard)
예제 #2
0
def dct(n: int, type: int = 2, normalized: bool = False) -> nn.Module:
    """ Build a Butterfly-based nn.Module that exactly performs the DCT.
    Parameters:
        n: size of the DCT. Must be a power of 2.
        type: one of 2, 3 or 4 (the only supported types). See scipy.fft.dct's notes.
        normalized: if True, corresponds to the orthogonal DCT (see scipy.fft.dct's notes)
    Returns:
        An nn.Sequential made of a fixed permutation, real<->complex conversions and a
        Butterfly with the pre/post-processing diagonals folded in.
    """
    assert type in [2, 3, 4]
    # Input reordering applied before the FFT: even positions first, then the odd positions
    # in reverse order, e.g. [0, 1, 2, 3] -> [0, 2, 3, 1].
    positions = torch.arange(n)
    perm = torch.cat((positions[::2], positions[1::2].flip([0])))
    br = bitreversal_permutation(n, pytorch_format=True)
    # Diagonal 2 * exp(-i * pi * k / (2n)), k = 0..n-1.
    postprocess_phase = -1j * math.pi * torch.arange(0.0, n) / (2 * n)
    postprocess_diag = 2 * torch.exp(postprocess_phase)

    if type not in [2, 4]:
        assert type == 3
        b = ifft(n, normalized=normalized, br_first=False, with_br_perm=False)
        postprocess_diag[0] /= 2.0
        if normalized:
            postprocess_diag[1:] /= math.sqrt(2)
        else:
            # We want iFFT with the scaling of 1.0 instead of 1 / n
            with torch.no_grad():
                b.twiddle *= 2
        # For type 3 the conjugated diagonal is applied before the inverse FFT.
        b = diagonal_butterfly(b, postprocess_diag.conj(), diag_first=True)
        perm_inverse = invert(perm)
        return nn.Sequential(Real2Complex(), b, Complex2Real(),
                             FixedPermutation(br[perm_inverse]))

    # Types 2 and 4 share the forward-FFT construction.
    b = fft(n, normalized=normalized, br_first=True, with_br_perm=False)
    if type == 4:
        even_phase = -1j * math.pi / (2 * n) * (torch.arange(0.0, n, 2) + 0.5)
        odd_phase = 1j * math.pi / (2 * n) * (torch.arange(1.0, n, 2) + 0.5)
        # Interleave the even-index and odd-index multipliers back into one length-n vector.
        preprocess_diag = torch.stack(
            (torch.exp(even_phase), torch.exp(odd_phase)), dim=-1).flatten()
        # This preprocess_diag is defined before the permutation.
        # To move it after the permutation, we have to permute the diagonal.
        b = diagonal_butterfly(b, preprocess_diag[perm[br]], diag_first=True)
    if normalized:
        if type == 2:
            postprocess_diag[0] /= 2.0
            postprocess_diag[1:] /= math.sqrt(2)
        elif type == 4:
            postprocess_diag /= math.sqrt(2)
    b = diagonal_butterfly(b, postprocess_diag, diag_first=False)
    return nn.Sequential(FixedPermutation(perm[br]), Real2Complex(), b,
                         Complex2Real())
예제 #3
0
파일: xd.py 프로젝트: nick11roberts/XD
def fourier_diag(*nd,
                 diags=None,
                 inv=False,
                 diag_first=True,
                 with_br_perm=False,
                 **kwargs):
    '''builds an n-dimensional FFT Butterfly matrix multiplied by a diagonal matrix
    Args:
        nd: input sizes of each dimension
        diags: torch.Tensor vectors specifying diagonals, one per dimension; if None uses the identity
        inv: return inverse FFT
        diag_first: returns FFT * diagonal; if False returns diagonal * FFT
        with_br_perm: uses bit-reversal permutation
        kwargs: passed to torch_butterfly.special.fft
    Returns:
        TensorProduct object
    '''

    kwargs['with_br_perm'] = with_br_perm

    if diags is None:
        # Without diagonals, 1-d and 2-d inputs are delegated directly to the (i)fft / (i)fft2d
        # constructors; any other number of dimensions falls through with all-ones diagonals.
        num_dims = len(nd)
        if num_dims == 1:
            transform_1d = ifft if inv else fft
            return TensorProduct(transform_1d(*nd, **kwargs))
        if num_dims == 2:
            transform_2d = ifft2d if inv else fft2d
            return TensorProduct(transform_2d(*nd, **kwargs))
        diags = [torch.ones(n) for n in nd]

    allowed_kwargs = {'normalized', 'br_first', 'with_br_perm'}
    assert set(kwargs.keys()).issubset(allowed_kwargs) and not with_br_perm, \
            "invalid kwargs when using diags or >2 dims"
    transform = ifft if inv else fft
    # One diagonal-fused 1-d transform per (size, diagonal) pair.
    factors = [
        diagonal_butterfly(transform(n, **kwargs), diag, diag_first=diag_first)
        for n, diag in zip(nd, diags)
    ]
    return TensorProduct(*factors)
예제 #4
0
def hadamard_diagonal(diagonals: torch.Tensor,
                      normalized: bool = False,
                      increasing_stride: bool = True,
                      separate_diagonal: bool = True) -> nn.Module:
    """ Build a Butterfly-based nn.Module that multiplies by H D H D ... H D,
    where H is the Hadamard matrix and D is a diagonal matrix.
    Parameters:
        diagonals: (k, n), where k is the number of diagonal matrices and n is the dimension of the
            Hadamard transform.
        normalized: if True, corresponds to the orthogonal Hadamard transform
                    (i.e. multiplied by 1/sqrt(n))
        increasing_stride: stride direction of the first Hadamard factor; the direction is
                           flipped for every other factor.
        separate_diagonal: if True, returns an nn.Sequential alternating Diagonal and Hadamard
                           modules; if False, each diagonal is combined into its Butterfly and
                           the Butterflies are merged with butterfly_product.
    """
    # Unpacking also enforces that `diagonals` is 2-dimensional.
    _, n = diagonals.shape

    def make_hadamard(position: int):
        # Odd positions use the opposite stride direction from `increasing_stride`.
        return hadamard(n, normalized, increasing_stride != (position % 2 == 1))

    if separate_diagonal:
        stages = []
        for position, diag in enumerate(diagonals.unbind()):
            stages.append(Diagonal(diagonal_init=diag))
            stages.append(make_hadamard(position))
        return nn.Sequential(*stages)

    fused = [
        diagonal_butterfly(make_hadamard(position), diag, diag_first=True)
        for position, diag in enumerate(diagonals.unbind())
    ]
    return reduce(butterfly_product, fused)
예제 #5
0
파일: xd.py 프로젝트: nick11roberts/XD
    def __init__(self,
                 in_size,
                 in_channels,
                 out_channels,
                 arch_init='ortho',
                 weight_init=nn.init.kaiming_normal_,
                 kmatrix_depth=1,
                 base=2,
                 max_kernel_size=1,
                 padding=None,
                 stride=1,
                 arch_shape=None,
                 weight=None,
                 global_biasing='additive',
                 channel_gating='complete',
                 perturb=0.0,
                 crop_init=slice(0),
                 dilation_init=1,
                 padding_mode='circular',
                 bias=None,
                 checkpoint=False,
                 fourier_position=-1,
                 _swap=False):
        '''Set up the weight tensor, the three K-matrices ('K', 'L', 'M'), the global bias 'b'
        and the channel gate 'C' of the module.

        Args:
            in_size: input size; an int is treated as a 2-dimensional input
            in_channels: number of input channels
            out_channels: number of output_channels
            arch_init: 'ortho' or $OPTYPE (e.g. 'skip') or $OPTYPE'_'$KERNELSIZE (e.g. 'conv_3x3')
            weight_init: function that initializes weight tensor
            kmatrix_depth: depth of each kmatrix (an int or one value per K-matrix)
            base: base to use for kmatrix (must be 2 or 4)
            max_kernel_size: maximum kernel size
            padding: determines padding; by default sets padding according to arch_init
            stride: governs subsampling
            arch_shape: architecture that determines the output shape; uses arch_init by default
            weight: model weights
            global_biasing: 'additive' or 'interp' or False
            channel_gating: 'complete' or 'interp' or False
            perturb: scale of perturbation to arch params
            crop_init: input slice(s) to crop
            dilation_init: kernel dilation at initialization
            padding_mode: 'circular' or 'zeros'; for 'zeros' will adjust in_size as needed
            bias: optional bias parameter
            checkpoint: apply checkpointing to kmatrix operations
            fourier_position: where to put each Fourier matrix when warm starting; -1 applies it last
            _swap: if True, re-initialize an already constructed module: nn.Module setup, weight
                creation and buffer registration are skipped, the existing 'K'/'L'/'M' attributes
                are reused, and 'b' / 'C' are overwritten through their .data
        '''

        if not _swap:
            # '_swap' variable allows for fast re-initialization of a module; useful for computing metrics
            super(XD, self).__init__()
            self._init_args = (in_size, in_channels, out_channels)
            # NOTE(review): the 'arch_shape' entry records arch_init rather than the arch_shape
            # argument -- presumably so a later re-initialization keeps the original output
            # shape; confirm this is intended.
            self._init_kwargs = {
                'arch_shape': arch_init,
                'padding': padding,
                'crop_init': crop_init,
                'dilation_init': dilation_init,
                'padding_mode': padding_mode,
                'checkpoint': checkpoint,
                'fourier_position': fourier_position
            }
        # Validate the enumerated options.
        assert base in {2, 4}, "'base' must be 2 or 4"
        assert global_biasing in {'additive', 'interp',
                                  False}, "invalid value for 'global_biasing'"
        assert channel_gating in {'complete', 'interp',
                                  False}, "invalid value for 'channel_gating'"

        self.checkpoint = checkpoint
        self.base = base
        self.chan = (out_channels, in_channels)
        # One depth per K-matrix; consumed below in the order K, L, M.
        self.depth = int2tuple(kmatrix_depth, length=3)
        # A bare int in_size means a 2-dimensional input.
        self.dims = 2 if type(in_size) == int else len(in_size)
        in_size = int2tuple(in_size, length=self.dims)
        if padding_mode == 'zeros':
            # increases effective input size if required due to zero-padding
            padding = int2tuple(0 if padding is None else padding,
                                length=self.dims)
            in_size = tuple(n + 2 * p for n, p in zip(in_size, padding))
            # Flat (p, p) pair for every dimension, in the order of in_size.
            self.zero_pad = tuple(sum(([p, p] for p in padding), []))
            # The padding is now part of in_size, so none is passed on to _parse_init.
            padding = [0] * self.dims
        else:
            self.zero_pad = ()
        # Round every input dimension up to the next power of two.
        self.in_size = tuple(2**math.ceil(math.log2(n)) for n in in_size)
        crop_init = int2tuple(crop_init, length=self.dims)
        # dilation_init is kept in reversed dimension order, matching self.nd / self.kd below.
        dilation_init = tuple(
            reversed(int2tuple(dilation_init, length=self.dims)))
        self.max_kernel_size, kd_init, skips, fourier_init, diagonal_init, self.unpadding = self._parse_init(
            arch_init, max_kernel_size, padding, arch_shape, dilation_init,
            _swap)
        # When set, the diagonals of 'L' are initialized to zeros in the loop below.
        zeroL = diagonal_init and global_biasing == 'additive'
        # Input sizes and kernel sizes in reversed dimension order.
        self.nd = tuple(reversed(self.in_size))
        self.kd = tuple(reversed(self.max_kernel_size))
        # Half of each kernel size (floor division).
        self.pd = tuple(k // 2 for k in self.kd)
        self.stride = int2tuple(stride, length=self.dims)
        if self.dims > 3:
            assert all(
                s == 1
                for s in self.stride), "must have stride 1 if using >3 dims"
            self.subsample = nn.Sequential(
            )  # TODO: handle stride>1 for >3 dimensional XD-op
        else:
            # Pooling with kernel size 1 in every dimension and the requested stride.
            self.subsample = AvgPool(self.dims)(kernel_size=[1] * self.dims,
                                                stride=self.stride)

        if not _swap:
            # Fresh weight of shape (out_channels, in_channels, *max_kernel_size).
            self.weight = nn.Parameter(
                torch.Tensor(out_channels, in_channels, *self.max_kernel_size))
            weight_init(self.weight)
        if not weight is None:
            # A same-shaped nn.Parameter is adopted directly; anything else is written into
            # the existing weight data through _offset_insert.
            if type(weight
                    ) == nn.Parameter and self.weight.shape == weight.shape:
                self.weight = weight
            else:
                self._offset_insert(self.weight.data,
                                    weight.data.to(self.weight.device))
        # A plain tensor bias is wrapped in a Parameter; any other value is stored as given.
        self.bias = nn.Parameter(bias) if type(bias) == torch.Tensor else bias

        # Offsets of a min(out, in)-sized square block inside the (out_channels, in_channels) grid.
        channels = min(self.chan)
        inoff, outoff = int(0.5 * (in_channels - channels)), int(
            0.5 * (out_channels - channels))
        if not _swap:
            self.register_buffer('diag', None, persistent=False)
            self.register_buffer('kron', None, persistent=False)
            # Constant buffers passed through self.r2c (presumably real -> complex; confirm).
            self.register_buffer('_one', self.r2c(torch.ones(1)))
            self.register_buffer('_1', self.r2c(torch.ones(self.chan)))
            self.register_buffer('_I', self.r2c(torch.zeros(self.chan)))
            # _I holds an identity block at the channel offsets computed above.
            self._I[outoff:outoff + channels,
                    inoff:inoff + channels] = torch.eye(channels)

        # Build (or, under _swap, reuse) the K-matrices 'K', 'L' and 'M', each paired with its
        # per-dimension diagonals, its depth and its Fourier position.
        for (kmatrix_name, diags), depth, fpos in zip(
            [
                ('K', [self.diag_K(n, s)
                       for n, s in zip(self.nd, skips)]),  # handles strides
                ('L', [
                    torch.zeros(n) if zeroL else self.diag_L(n, k)
                    for n, k in zip(self.nd, kd_init)
                ]),  # handles kernel size limits
                ('M', [self.diag_M(n, c) for n, c in zip(self.nd, crop_init)])
            ],  # handles input cropping
                self.depth,
                int2tuple(fourier_position, length=3)):
            if _swap:
                kmatrix = getattr(self, kmatrix_name)
            else:
                kmatrix_kwargs = {
                    'bias': False,
                    'increasing_stride': kmatrix_name == 'K',
                    'complex': True,
                    'init': 'identity' if fourier_init else arch_init,
                    'nblocks': depth,
                }
                kmatrix = TensorProduct(*(Butterfly(n, n, **kmatrix_kwargs)
                                          for n in self.nd))
            if fourier_init:
                # Warm start: obtain a Fourier K-matrix with the (perturbed) diagonals and copy
                # its twiddle into block position fpos of each dimension's Butterfly.
                # Dimensions with dilation != 1 use an all-ones diagonal here.
                fourier_kmatrix = self.get_fourier(
                    kmatrix_name,
                    *self.nd,
                    diags=[
                        self._perturb(
                            diag if d == 1 else torch.ones(diag.shape),
                            perturb) for d, diag in zip(dilation_init, diags)
                    ])
                if kmatrix_name == 'L' and any(d > 1 for d in dilation_init):
                    # With dilation, twiddle positions 0 and 1 are overwritten below, so the
                    # Fourier factor is placed at position 2 or later.
                    fpos = max(2, depth + fpos if fpos < 0 else fpos)
                for dim, d, k, n in zip(range(1, self.dims + 1), dilation_init,
                                        self.kd, self.nd):
                    if kmatrix_name == 'L' and d > 1:
                        # handles initialization of middle K-matrix for the case of dilated convs; requires kmatrix_depth >= 3
                        assert depth >= 3, "using dilation > 1 requires depth at least (1, 3, 1)"
                        kmatrix.getmap(
                            dim).twiddle.data[:, :2] = diagonal_butterfly(
                                perm2butterfly(self._atrous_permutation(
                                    n, k, d),
                                               complex=True),
                                diags[dim - 1],
                                diag_first=True).twiddle.data.to(
                                    kmatrix.device())
                    kmatrix.getmap(dim).twiddle.data[
                        0, fpos] = fourier_kmatrix.getmap(dim).twiddle.data[
                            0, 0].to(kmatrix.device())
            if base == 4:
                # Convert every per-dimension map with its to_base4() method.
                for dim in range(1, self.dims + 1):
                    kmatrix.setmap(dim, kmatrix.getmap(dim).to_base4())
            setattr(self, kmatrix_name, kmatrix)

        self.global_biasing = global_biasing
        # Filter inserted into a zero tensor of the maximum kernel size: a uniform averaging
        # filter over kd_init when arch_init contains 'pool', otherwise a single one.
        filt = self._offset_insert(
            torch.zeros(1, 1, *self.max_kernel_size),
            torch.ones(1, 1, *kd_init) / np.prod(kd_init)
            if 'pool' in arch_init else torch.ones(1, 1, *[1] * self.dims))
        if self.global_biasing == 'additive':
            if diagonal_init:
                # b is the circularly padded filter passed through a Fourier 'L' K-matrix
                # built with the diag_L diagonals.
                L = self.get_fourier('L',
                                     *self.nd,
                                     diags=[
                                         self.diag_L(n, k)
                                         for n, k in zip(self.nd, kd_init)
                                     ])
                b = L(self.r2c(self._circular_pad(filt)))
            else:
                b = self.r2c(torch.zeros(1, 1, *self.in_size))
        elif self.global_biasing == 'interp':
            # b has 1 + prod(max_kernel_size) entries: a leading scalar followed by the
            # flattened filter (all zeros when diagonal_init is falsy).
            if diagonal_init:
                b = self.r2c(torch.cat((torch.ones(1), filt.flatten())))
            else:
                b = self.r2c(torch.zeros(1 + np.prod(self.max_kernel_size)))
        else:
            # Global biasing disabled: b is built from an empty tensor.
            b = self.r2c(torch.Tensor(0))
        if _swap:
            # Re-initialization keeps the existing Parameter and only replaces its data.
            self.b.data = b.to(self.b.device)
        else:
            self.register_parameter('b', nn.Parameter(b))

        self.channel_gating = channel_gating
        if self.channel_gating == 'complete':
            if diagonal_init:
                # Identity block at the channel offsets, zeros elsewhere.
                C = self.r2c(torch.zeros(self.chan))
                C[outoff:outoff + channels,
                  inoff:inoff + channels] = torch.eye(channels)
            else:
                C = self.r2c(torch.ones(self.chan))
        elif self.channel_gating == 'interp':
            # Single scalar: 1.0 when diagonal_init is truthy, else 0.0.
            C = self.r2c(torch.Tensor([float(diagonal_init)]))
        else:
            # Channel gating disabled: C is built from an empty tensor.
            C = self.r2c(torch.Tensor(0))
        if _swap:
            self.C.data = C.to(self.C.device)
        else:
            self.register_parameter('C', nn.Parameter(C))

        # Move the whole module to the device reported by self.device().
        self.to(self.device())
예제 #6
0
def acdc(diag1: torch.Tensor,
         diag2: torch.Tensor,
         dct_first: bool = True,
         separate_diagonal: bool = True) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs either the multiplication:
        x -> diag2 @ iDCT @ diag1 @ DCT @ x
    or
        x -> diag2 @ DCT @ diag1 @ iDCT @ x.
    In the paper [1], the math describes the 2nd type while the implementation uses the 1st type.
    Note that the DCT and iDCT are normalized.
    [1] Marcin Moczulski, Misha Denil, Jeremy Appleyard, Nando de Freitas.
    ACDC: A Structured Efficient Linear Layer.
    http://arxiv.org/abs/1511.05946
    Parameters:
        diag1: (n,), where n is a power of 2.
        diag2: (n,), where n is a power of 2.
        dct_first: if True, uses the first type above; otherwise use the second type.
        separate_diagonal: if True, diag1 and diag2 are kept as standalone Diagonal modules;
                           if False, they are combined into the Butterfly parts.
    Returns:
        An nn.Sequential implementing the selected map.
    """
    n, = diag1.shape
    assert diag2.shape == (n, )
    assert n == 1 << int(math.ceil(math.log2(n))), 'n must be a power of 2'
    # Construct the permutation before the FFT: separate the even and odd and then reverse the odd
    # e.g., [0, 1, 2, 3] -> [0, 2, 3, 1].
    # This permutation is actually in B (not just B^T B or B B^T). This can be checked with
    # perm2butterfly.
    perm = torch.arange(n)
    perm = torch.cat((perm[::2], perm[1::2].flip([0])))
    perm_inverse = invert(perm)
    br = bitreversal_permutation(n, pytorch_format=True)
    postprocess_diag = 2 * torch.exp(-1j * math.pi * torch.arange(0.0, n) /
                                     (2 * n))
    # Normalize
    postprocess_diag[0] /= 2.0
    postprocess_diag[1:] /= math.sqrt(2)
    if dct_first:
        b_fft = fft(n, normalized=True, br_first=False, with_br_perm=False)
        b_ifft = ifft(n, normalized=True, br_first=True, with_br_perm=False)
        # b1 = DCT part (FFT followed by the post-processing diagonal),
        # b2 = iDCT part (conjugate diagonal followed by the inverse FFT).
        b1 = diagonal_butterfly(b_fft, postprocess_diag[br], diag_first=False)
        b2 = diagonal_butterfly(b_ifft,
                                postprocess_diag.conj()[br],
                                diag_first=True)
        if not separate_diagonal:
            # Fold diag1 into b1 (not the bare b_fft) so the DCT post-processing diagonal is
            # retained whether diagonal_butterfly mutates its input or returns a new Butterfly.
            b1 = diagonal_butterfly(b1, diag1[br], diag_first=False)
            b2 = diagonal_butterfly(b2, diag2[perm], diag_first=False)
            return nn.Sequential(FixedPermutation(perm), Real2Complex(), b1,
                                 Complex2Real(), Real2Complex(), b2,
                                 Complex2Real(),
                                 FixedPermutation(perm_inverse))
        else:
            return nn.Sequential(FixedPermutation(perm), Real2Complex(), b1,
                                 Complex2Real(),
                                 Diagonal(diagonal_init=diag1[br]),
                                 Real2Complex(), b2, Complex2Real(),
                                 Diagonal(diagonal_init=diag2[perm]),
                                 FixedPermutation(perm_inverse))
    else:
        b_fft = fft(n, normalized=True, br_first=True, with_br_perm=False)
        b_ifft = ifft(n, normalized=True, br_first=False, with_br_perm=False)
        # b1 = iDCT part, b2 = DCT part.
        b1 = diagonal_butterfly(b_ifft,
                                postprocess_diag.conj(),
                                diag_first=True)
        b2 = diagonal_butterfly(b_fft, postprocess_diag, diag_first=False)
        if not separate_diagonal:
            b1 = diagonal_butterfly(b1, diag1[perm][br], diag_first=False)
            # Fold diag2 into b2 (not the bare b_fft) for the same reason as above.
            b2 = diagonal_butterfly(b2, diag2, diag_first=False)
            return nn.Sequential(Real2Complex(), b1, Complex2Real(),
                                 Real2Complex(), b2, Complex2Real())
        else:
            return nn.Sequential(Real2Complex(), b1, Complex2Real(),
                                 Diagonal(diagonal_init=diag1[perm][br]),
                                 Real2Complex(), b2, Complex2Real(),
                                 Diagonal(diagonal_init=diag2))
예제 #7
0
def circulant(col, transposed=False, separate_diagonal=True) -> nn.Module:
    """ Build a Butterfly-based nn.Module that exactly performs multiplication by a circulant
    matrix.
    Parameters:
        col: torch.Tensor of size (n, ). The first column of the circulant matrix.
        transposed: if True, then the circulant matrix is transposed, i.e. col is the first *row*
                    of the matrix.
        separate_diagonal: if True, the returned nn.Module is Butterfly, Diagonal, Butterfly.
                           if False, the diagonal is combined into the Butterfly part.
    Returns:
        The Butterfly-based module; for a real `col` it is wrapped between Real2Complex and
        Complex2Real.
    """
    assert col.dim() == 1, 'Vector col must have dimension 1'
    input_is_complex = col.is_complex()
    n = col.shape[0]
    log_n = int(math.ceil(math.log2(n)))
    next_pow2 = 1 << log_n
    size_is_pow2 = n == next_pow2
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
    # I've only figured out how to pad to size 1 << (log_n + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    n_extended = n if size_is_pow2 else 2 * next_pow2

    # Forward / inverse FFT butterflies of the extended size, exposed with the original
    # input / output size n.
    b_fft = fft(n_extended,
                normalized=True,
                br_first=False,
                with_br_perm=False).to(col.device)
    b_fft.in_size = n
    b_ifft = ifft(n_extended,
                  normalized=True,
                  br_first=True,
                  with_br_perm=False).to(col.device)
    b_ifft.out_size = n

    if not size_is_pow2:
        # Zero-pad, then append a second copy of col so the total length is n_extended.
        zero_padded = F.pad(col, (0, 2 * (next_pow2 - n)))
        col = torch.cat((zero_padded, col))
    if not input_is_complex:
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of the
    # circulant matrix.
    eigenvalues = torch.fft.fft(col, norm=None)
    if transposed:
        # We could have just transposed the iFFT * Diag * FFT to get FFT * Diag * iFFT.
        # Instead we use the fact that row is the reverse of col, but the 0-th element stays put.
        # This corresponds to the same reversal in the frequency domain.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
        eigenvalues = torch.cat((eigenvalues[:1], eigenvalues[1:].flip([0])))
    br_perm = bitreversal_permutation(n_extended,
                                      pytorch_format=True).to(col.device)
    # Equivalent to eigenvalues[..., br_perm]
    diag = index_last_dim(eigenvalues, br_perm)

    if separate_diagonal:
        if input_is_complex:
            return nn.Sequential(b_fft, Diagonal(diagonal_init=diag), b_ifft)
        return nn.Sequential(Real2Complex(), b_fft,
                             Diagonal(diagonal_init=diag), b_ifft,
                             Complex2Real())

    # Combine the diagonal with the last twiddle factor of b_fft
    with torch.no_grad():
        b_fft = diagonal_butterfly(b_fft,
                                   diag,
                                   diag_first=False,
                                   inplace=True)
    # Combine the b_fft and b_ifft into one Butterfly (with nblocks=2).
    combined = butterfly_product(b_fft, b_ifft)
    combined.in_size = n
    combined.out_size = n
    if input_is_complex:
        return combined
    return nn.Sequential(Real2Complex(), combined, Complex2Real())