Exemplo n.º 1
def coldfilt(X, ha, hb, highpass=False, mode='symmetric'):
    """Filter along dimension 2 (rows) with ``ha``/``hb`` and decimate by 2.

    The row axis is extended with ``symm_pad`` and split into two
    interleaved phases. The first phase is filtered with ``ha`` and the
    second with ``hb``, using a stride-2 grouped convolution. The two
    results are then interleaved back together row by row.

    Args:
        X: input tensor of shape (batch, ch, r, c); ``r`` must be a
            multiple of 4. ``None`` or a tensor of shape ``()``/``(0,)`` is
            treated as empty.
        ha, hb: filters, repeated per channel and used directly as conv2d
            weights. Assumed shape (1, 1, m, 1) — TODO confirm with callers.
        highpass: if True the ``hb`` output rows come first in each
            interleaved pair, otherwise the ``ha`` output rows do.
        mode: only ``'symmetric'`` is implemented.

    Returns:
        Tensor of shape (batch, ch, r // 2, c), or ``zeros(1, 1, 1, 1)`` for
        empty input.

    Raises:
        ValueError: if ``r`` is not a multiple of 4.
        NotImplementedError: for any ``mode`` other than ``'symmetric'``.
    """
    # Empty input: the previous combined guard dereferenced X.device even
    # when X was None.
    if X is None:
        return torch.zeros(1, 1, 1, 1)
    if X.shape in (torch.Size([]), torch.Size([0])):
        return torch.zeros(1, 1, 1, 1, device=X.device)
    batch, ch, r, c = X.shape
    r2 = r // 2
    if r % 4 != 0:
        raise ValueError('No. of rows in X must be a multiple of 4\n' +
                         'X was {}'.format(X.shape))

    if mode == 'symmetric':
        m = ha.shape[2]
        xe = symm_pad(r, m)
        # Two row phases of the extended signal, stacked on the channel axis
        # so that one grouped conv applies ha to the first and hb to the
        # second.
        X1 = X[:, :, xe[2::2]]
        X2 = X[:, :, xe[3::2]]
        h = torch.cat((ha.repeat(ch, 1, 1, 1), hb.repeat(ch, 1, 1, 1)), dim=0)
        Y = F.conv2d(torch.cat((X1, X2), dim=1), h, stride=(2, 1),
                     groups=ch * 2)
        Y1 = Y[:, :ch]
        Y2 = Y[:, ch:]
    else:
        raise NotImplementedError()

    # Stack the two results on a new axis before the columns, then view as
    # [batch, ch, r/2, c]. This interleaves the ROWS of Y1 and Y2.
    if highpass:
        Y = torch.stack((Y2, Y1), dim=-2).view(batch, ch, r2, c)
    else:
        Y = torch.stack((Y1, Y2), dim=-2).view(batch, ch, r2, c)

    return Y
Exemplo n.º 2
def colfilter(X, h):
    """Filter every channel of X along dimension 2 (rows) with ``h``.

    The row axis is extended with ``symm_pad(r, h.shape[2] // 2)`` before a
    grouped (per-channel) conv2d.

    Args:
        X: input tensor; dims 1 and 2 are read as (ch, r). ``None`` or a
            tensor of shape ``()``/``(0,)`` is treated as empty.
        h: filter, repeated per channel as conv2d weight. Assumed shape
            (1, 1, L, 1) — TODO confirm with callers.

    Returns:
        The filtered tensor, or ``zeros(1, 1, 1, 1)`` for empty input.
    """
    # Empty input: the previous combined guard dereferenced X.device even
    # when X was None.
    if X is None:
        return torch.zeros(1, 1, 1, 1)
    if X.shape in (torch.Size([]), torch.Size([0])):
        return torch.zeros(1, 1, 1, 1, device=X.device)
    ch, r = X.shape[1:3]
    m = h.shape[2] // 2
    xe = symm_pad(r, m)
    return F.conv2d(X[:, :, xe], h.repeat(ch, 1, 1, 1), groups=ch)
Exemplo n.º 3
def rowdfilt(X, ha, hb, highpass=False, mode='symmetric'):
    """Filter along dimension 3 (columns) with ``ha``/``hb`` and decimate by 2.

    The column axis is extended with ``symm_pad`` and split into two
    interleaved phases. The first phase is filtered with ``ha`` and the
    second with ``hb``, using a stride-2 grouped convolution. The two
    results are then interleaved back together column by column.

    Args:
        X: input tensor of shape (batch, ch, r, c); ``c`` must be a
            multiple of 4. ``None`` or a tensor of shape ``()``/``(0,)`` is
            treated as empty.
        ha, hb: filters with ``m = ha.shape[2]`` taps, reshaped to
            (1, 1, 1, m) so they act along the last axis.
        highpass: if True the ``hb`` output columns come first in each
            interleaved pair, otherwise the ``ha`` output columns do.
        mode: only ``'symmetric'`` is implemented.

    Returns:
        Tensor of shape (batch, ch, r, c // 2), or ``zeros(1, 1, 1, 1)`` for
        empty input.

    Raises:
        ValueError: if ``c`` is not a multiple of 4.
        NotImplementedError: for any ``mode`` other than ``'symmetric'``.
    """
    # Empty input: the previous combined guard dereferenced X.device even
    # when X was None.
    if X is None:
        return torch.zeros(1, 1, 1, 1)
    if X.shape in (torch.Size([]), torch.Size([0])):
        return torch.zeros(1, 1, 1, 1, device=X.device)
    batch, ch, r, c = X.shape
    c2 = c // 2
    if c % 4 != 0:
        raise ValueError('No. of cols in X must be a multiple of 4\n' +
                         'X was {}'.format(X.shape))

    if mode == 'symmetric':
        m = ha.shape[2]
        xe = symm_pad(c, m)
        # Two column phases stacked on the channel axis: ha is applied to
        # the first ch channels, hb to the last ch channels.
        X = torch.cat((X[:, :, :, xe[2::2]], X[:, :, :, xe[3::2]]), dim=1)
        h = torch.cat((ha.reshape(1, 1, 1, m).repeat(ch, 1, 1, 1),
                       hb.reshape(1, 1, 1, m).repeat(ch, 1, 1, 1)), dim=0)
        X = F.conv2d(X, h, stride=(1, 2), groups=ch * 2)
    else:
        raise NotImplementedError()

    # Stack the two results on a new trailing axis, then view as
    # [batch, ch, r, c/2]. This interleaves the COLUMNS of both results.
    if highpass:
        Y = torch.stack((X[:, ch:], X[:, :ch]), dim=-1).view(batch, ch, r, c2)
    else:
        Y = torch.stack((X[:, :ch], X[:, ch:]), dim=-1).view(batch, ch, r, c2)

    return Y
Exemplo n.º 4
def rowfilter(X, h):
    """Filter every channel of X along dimension 3 (columns) with ``h``.

    ``h`` is transposed from (…, L, 1) to (…, 1, L) so it acts along the
    last axis. The column axis is extended with
    ``symm_pad(c, h.shape[2] // 2)`` before a grouped (per-channel) conv2d.

    Args:
        X: input tensor; dims 1..3 are read as (ch, r, c). ``None`` or a
            tensor of shape ``()``/``(0,)`` is treated as empty.
        h: filter. Assumed shape (1, 1, L, 1) — TODO confirm with callers.

    Returns:
        The filtered tensor, or ``zeros(1, 1, 1, 1)`` for empty input.
    """
    # Empty input: the previous combined guard dereferenced X.device even
    # when X was None.
    if X is None:
        return torch.zeros(1, 1, 1, 1)
    if X.shape in (torch.Size([]), torch.Size([0])):
        return torch.zeros(1, 1, 1, 1, device=X.device)
    ch, _, c = X.shape[1:]
    m = h.shape[2] // 2
    xe = symm_pad(c, m)
    h = h.transpose(2, 3).contiguous()
    return F.conv2d(X[:, :, :, xe], h.repeat(ch, 1, 1, 1), groups=ch)
Exemplo n.º 5
def colfilter(X, h, mode='symmetric'):
    """Filter every channel of X along dimension 2 (rows) with ``h``.

    Args:
        X: input tensor; dims 1 and 2 are read as (ch, r). ``None`` or a
            tensor of shape ``()``/``(0,)`` is treated as empty.
        h: filter, repeated per channel as conv2d weight. Assumed shape
            (1, 1, L, 1) — TODO confirm with callers.
        mode: ``'symmetric'`` extends rows with ``symm_pad(r, L // 2)``.
            Any other value zero-pads ``L // 2`` rows on each side through
            conv2d's ``padding``.

    Returns:
        The filtered tensor, or ``zeros(1, 1, 1, 1)`` for empty input.
    """
    # Empty input: the previous combined guard dereferenced X.device even
    # when X was None.
    if X is None:
        return torch.zeros(1, 1, 1, 1)
    if X.shape in (torch.Size([]), torch.Size([0])):
        return torch.zeros(1, 1, 1, 1, device=X.device)
    ch, r = X.shape[1:3]
    m = h.shape[2] // 2
    weight = h.repeat(ch, 1, 1, 1)
    if mode == 'symmetric':
        xe = symm_pad(r, m)
        return F.conv2d(X[:, :, xe], weight, groups=ch)
    # Only non-symmetric modes reach here. Previously this line ran
    # unconditionally and overwrote the symmetric result.
    return F.conv2d(X, weight, groups=ch, padding=(m, 0))
Exemplo n.º 6
def colifilt(X, ha, hb, highpass=False, mode='symmetric'):
    """Filter along dimension 2 (rows) with the polyphase parts of ``ha``/``hb``
    and interpolate by 2.

    ``ha`` and ``hb`` are split into their even (``::2``) and odd (``1::2``)
    taps. Four phase-shifted selections of the symmetrically extended rows
    are each filtered by one of those four sub-filters in a single grouped
    conv2d. The four results are then interleaved row by row into an
    output with ``2 * r`` rows.

    Args:
        X: input tensor of shape (batch, ch, r, c); ``r`` must be even.
            ``None`` or a tensor of shape ``()``/``(0,)`` is treated as empty.
        ha, hb: filters with ``m = ha.shape[2]`` taps along dim 2.
            Presumably of shape (1, 1, m, 1) with even ``m`` — TODO confirm.
        highpass: selects the highpass ordering of the input phases.
        mode: accepted for API symmetry but not used; symmetric extension
            via ``symm_pad`` is always applied.

    Returns:
        Tensor of shape (batch, ch, 2 * r, c), or ``zeros(1, 1, 1, 1)`` for
        empty input.

    Raises:
        ValueError: if ``r`` is odd.
    """
    # Empty input: the previous combined guard dereferenced X.device even
    # when X was None.
    if X is None:
        return torch.zeros(1, 1, 1, 1)
    if X.shape in (torch.Size([]), torch.Size([0])):
        return torch.zeros(1, 1, 1, 1, device=X.device)
    m = ha.shape[2]
    m2 = m // 2
    hao = ha[:, :, 1::2]
    hae = ha[:, :, ::2]
    hbo = hb[:, :, 1::2]
    hbe = hb[:, :, ::2]
    batch, ch, r, c = X.shape
    if r % 2 != 0:
        raise ValueError('No. of rows in X must be a multiple of 2.\n' +
                         'X was {}'.format(X.shape))
    xe = symm_pad(r, m2)

    # The parity of m2 decides which polyphase taps pair with which row
    # phase. `highpass` swaps the phases within each pair.
    if m2 % 2 == 0:
        h1, h2, h3, h4 = hae, hbe, hao, hbo
        if highpass:
            X = torch.cat((X[:, :, xe[1:-2:2]], X[:, :, xe[:-2:2]],
                           X[:, :, xe[3::2]], X[:, :, xe[2::2]]), dim=1)
        else:
            X = torch.cat((X[:, :, xe[:-2:2]], X[:, :, xe[1:-2:2]],
                           X[:, :, xe[2::2]], X[:, :, xe[3::2]]), dim=1)
    else:
        h1, h2, h3, h4 = hao, hbo, hae, hbe
        if highpass:
            X = torch.cat((X[:, :, xe[2:-1:2]], X[:, :, xe[1:-1:2]],
                           X[:, :, xe[2:-1:2]], X[:, :, xe[1:-1:2]]), dim=1)
        else:
            X = torch.cat((X[:, :, xe[1:-1:2]], X[:, :, xe[2:-1:2]],
                           X[:, :, xe[1:-1:2]], X[:, :, xe[2:-1:2]]), dim=1)
    # One sub-filter per block of ch channels, matching the four phase
    # blocks concatenated above.
    h = torch.cat((h1.repeat(ch, 1, 1, 1), h2.repeat(ch, 1, 1, 1),
                   h3.repeat(ch, 1, 1, 1), h4.repeat(ch, 1, 1, 1)), dim=0)

    X = F.conv2d(X, h, groups=4 * ch)
    # Stack the 4 results on a new axis at dim 3, then view as
    # [batch, ch, 2r, c]. This interleaves their ROWS.
    X = torch.stack(
        [X[:, :ch], X[:, ch:2 * ch], X[:, 2 * ch:3 * ch], X[:, 3 * ch:]],
        dim=3).view(batch, ch, r * 2, c)

    return X
Exemplo n.º 7
def rowfilter(X, h, mode='symmetric'):
    """Filter every channel of X along dimension 3 (columns) with ``h``.

    ``h`` is transposed from (…, L, 1) to (…, 1, L) so it acts along the
    last axis.

    Args:
        X: input tensor; dims 1..3 are read as (ch, r, c). ``None`` or a
            tensor of shape ``()``/``(0,)`` is treated as empty.
        h: filter. Assumed shape (1, 1, L, 1) — TODO confirm with callers.
        mode: ``'symmetric'`` extends columns with ``symm_pad(c, L // 2)``.
            Any other value zero-pads ``L // 2`` columns on each side through
            conv2d's ``padding``.

    Returns:
        The filtered tensor, or ``zeros(1, 1, 1, 1)`` for empty input.
    """
    # Empty input: the previous combined guard dereferenced X.device even
    # when X was None.
    if X is None:
        return torch.zeros(1, 1, 1, 1)
    if X.shape in (torch.Size([]), torch.Size([0])):
        return torch.zeros(1, 1, 1, 1, device=X.device)
    ch, _, c = X.shape[1:]
    m = h.shape[2] // 2
    weight = h.transpose(2, 3).contiguous().repeat(ch, 1, 1, 1)
    if mode == 'symmetric':
        xe = symm_pad(c, m)
        return F.conv2d(X[:, :, :, xe], weight, groups=ch)
    # Only non-symmetric modes reach here. Previously this line ran
    # unconditionally and overwrote the symmetric result.
    return F.conv2d(X, weight, groups=ch, padding=(0, m))
Exemplo n.º 8
def rowifilt(X, ha, hb, highpass=False, mode='symmetric'):
    """Filter along dimension 3 (columns) with the polyphase parts of
    ``ha``/``hb`` and interpolate by 2.

    ``ha`` and ``hb`` are split into their even (``::2``) and odd (``1::2``)
    taps. Four phase-shifted selections of the symmetrically extended
    columns are each filtered by one of those four sub-filters in a single
    grouped conv2d. The four results are then interleaved column by column
    into an output with ``2 * c`` columns.

    Args:
        X: input tensor of shape (batch, ch, r, c); ``c`` must be even.
            ``None`` or a tensor of shape ``()``/``(0,)`` is treated as empty.
        ha, hb: filters with ``m = ha.shape[2]`` taps. Each polyphase part is
            reshaped to ``m // 2`` taps, which presumably requires even
            ``m`` — TODO confirm with callers.
        highpass: swaps the input phases within each (X1, X2) / (X3, X4)
            pair.
        mode: accepted for API symmetry but not used; symmetric extension
            via ``symm_pad`` is always applied.

    Returns:
        Tensor of shape (batch, ch, r, 2 * c), or ``zeros(1, 1, 1, 1)`` for
        empty input.

    Raises:
        ValueError: if ``c`` is odd.
    """
    # Empty input: the previous combined guard dereferenced X.device even
    # when X was None.
    if X is None:
        return torch.zeros(1, 1, 1, 1)
    if X.shape in (torch.Size([]), torch.Size([0])):
        return torch.zeros(1, 1, 1, 1, device=X.device)
    m = ha.shape[2]
    m2 = m // 2
    hao = ha[:, :, 1::2]
    hae = ha[:, :, ::2]
    hbo = hb[:, :, 1::2]
    hbe = hb[:, :, ::2]
    batch, ch, r, c = X.shape
    if c % 2 != 0:
        raise ValueError('No. of cols in X must be a multiple of 2.\n' +
                         'X was {}'.format(X.shape))
    xe = symm_pad(c, m2)

    # The parity of m2 decides which polyphase taps pair with which column
    # phase. Previously the odd branch ran unconditionally because its
    # `else:` was missing.
    if m2 % 2 == 0:
        h1, h2, h3, h4 = hae, hbe, hao, hbo
        X1 = X[:, :, :, xe[:-2:2]]
        X2 = X[:, :, :, xe[1:-2:2]]
        X3 = X[:, :, :, xe[2::2]]
        X4 = X[:, :, :, xe[3::2]]
    else:
        h1, h2, h3, h4 = hao, hbo, hae, hbe
        X1 = X[:, :, :, xe[1:-1:2]]
        X2 = X[:, :, :, xe[2:-1:2]]
        X3 = X[:, :, :, xe[1:-1:2]]
        X4 = X[:, :, :, xe[2:-1:2]]
    if highpass:
        X1, X2 = X2, X1
        X3, X4 = X4, X3
    h = torch.cat((h1.repeat(ch, 1, 1, 1), h2.repeat(ch, 1, 1, 1),
                   h3.repeat(ch, 1, 1, 1), h4.repeat(ch, 1, 1, 1)),
                  dim=0).reshape(4 * ch, 1, 1, m2)
    X = torch.cat((X1, X2, X3, X4), dim=1)

    Y = F.conv2d(X, h, groups=4 * ch)
    # Stack the 4 results on a new trailing axis (dim 4), then view as
    # [batch, ch, r, 2c]. This interleaves their COLUMNS.
    Y = torch.stack(
        [Y[:, :ch], Y[:, ch:2 * ch], Y[:, 2 * ch:3 * ch], Y[:, 3 * ch:]],
        dim=4).view(batch, ch, r, c * 2)
    return Y