示例#1
0
def test_equal_oddshape(size):
    """Compare DWTForward/DWTInverse with PyWavelets for a given spatial size.

    Two separately constructed forward/inverse pairs (same settings) are run
    on the same random input. Their reconstructions and coefficients are
    checked against ``pywt.wavedec2`` / ``pywt.waverec2``.

    Args:
        size: iterable of trailing spatial dimensions; the input tensor has
            shape ``(5, 4, *size)``. Presumably odd-sized, per the test name.
    """
    wave = 'db3'
    J = 3
    mode = 'symmetric'
    x = torch.randn(5, 4, *size).to(dev)
    dwt1 = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt1 = DWTInverse(wave=wave, mode=mode).to(dev)
    dwt2 = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt2 = DWTInverse(wave=wave, mode=mode).to(dev)

    yl1, yh1 = dwt1(x)
    x1 = iwt1((yl1, yh1))
    yl2, yh2 = dwt2(x)
    x2 = iwt2((yl2, yh2))

    # Reference decomposition / reconstruction from PyWavelets.
    coeffs = pywt.wavedec2(x.cpu().numpy(), wave, level=J, axes=(-2, -1),
                           mode=mode)
    X2 = pywt.waverec2(coeffs, wave, mode=mode)

    # Reconstructions must be moved to host memory before NumPy can compare
    # them; without .cpu() these checks fail when `dev` is a CUDA device.
    np.testing.assert_array_almost_equal(
        X2, x1.detach().cpu(), decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(
        X2, x2.detach().cpu(), decimal=PREC_FLT)

    # Lowpass coefficients.
    np.testing.assert_array_almost_equal(yl1.cpu(), coeffs[0], decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(yl2.cpu(), coeffs[0], decimal=PREC_FLT)

    # Highpass coefficients: yh[j] is compared with coeffs[J - j], i.e. the
    # two libraries list the levels in opposite order.
    for j in range(J):
        for b in range(3):
            np.testing.assert_array_almost_equal(
                coeffs[J-j][b], yh1[j][:, :, b].cpu(), decimal=PREC_FLT)
            np.testing.assert_array_almost_equal(
                coeffs[J-j][b], yh2[j][:, :, b].cpu(), decimal=PREC_FLT)
示例#2
0
def test_inv_j2(mode):
    """Run ``gradcheck`` on ``SFB2D.apply`` with random coefficients.

    The lowpass/highpass inputs and the inverse transform (whose ``g0``/``g1``
    row and column filters are fed to the function) are created inside
    ``set_double_precision()``. Presumably that makes float64 the default
    dtype, which is what ``gradcheck`` needs; confirm against the helper.

    Args:
        mode: padding mode, passed as the last argument of ``SFB2D.apply``.
    """
    with set_double_precision():
        lowpass = torch.randn(1, 3, 16, 16, device=dev, requires_grad=True)
        highpass = torch.randn(1, 3, 3, 16, 16, device=dev,
                               requires_grad=True)
        inverse = DWTInverse().to(dev)

    synthesis_filters = (inverse.g0_row, inverse.g1_row,
                         inverse.g0_col, inverse.g1_col)
    fn_args = (lowpass, highpass) + synthesis_filters + (mode,)
    gradcheck(SFB2D.apply, fn_args, eps=EPS, atol=ATOL)
示例#3
0
def test_commutativity(wave, J, j):
    """Check that the inverse DWT is additive over the level-``j`` subbands.

    Starting from the transform of an all-zero input, single bands of level
    ``j`` are copied in from the transform of a random input (or zeroed
    again), and the inverse is taken after each step. The reconstruction from
    two bands must equal the sum of the reconstructions from each band alone.

    Args:
        wave: wavelet passed to DWTForward / DWTInverse.
        J: number of decomposition levels.
        j: index of the level whose bands are manipulated.
    """
    channels = 3
    Y = torch.randn(4, channels, 128, 128, requires_grad=True, device=dev)
    dwt = DWTForward(J=J, wave=wave).to(dev)
    iwt = DWTInverse(wave=wave).to(dev)

    full = dwt(Y)
    sparse = dwt(torch.zeros_like(Y))
    source = full[1][j]      # level-j highpass of the random input
    target = sparse[1][j]    # same tensor object the inverse reads below
    lh, hl, hh = 0, 1, 2     # band indices along dim 2

    # Only LH non-zero
    target[:, :, lh] = source[:, :, lh]
    ya = iwt(sparse)
    # LH and HL non-zero
    target[:, :, hl] = source[:, :, hl]
    yab = iwt(sparse)
    # Zero LH again, leaving only HL
    target[:, :, lh] = torch.zeros_like(source[:, :, lh])
    yb = iwt(sparse)
    # HL and HH non-zero
    target[:, :, hh] = source[:, :, hh]
    ybc = iwt(sparse)
    # Zero HL again, leaving only HH
    target[:, :, hl] = torch.zeros_like(source[:, :, hl])
    yc = iwt(sparse)

    def _host(t):
        return t.detach().cpu()

    np.testing.assert_array_almost_equal(
        _host(ya + yb), _host(yab), decimal=PREC_FLT)
    np.testing.assert_array_almost_equal(
        _host(yc + yb), _host(ybc), decimal=PREC_FLT)
示例#4
0
def test_gradients_fwd(wave, J, mode):
    """Backward pass of the forward DWT must match the inverse DWT built
    from the reversed decomposition filters.

    A random upstream gradient is pushed through one output at a time (the
    lowpass first, then each highpass level). The gradient that reaches the
    input is compared with the inverse transform of that same upstream
    gradient, with every other output set to zero.
    """
    image = np.random.randn(5, 6, 128, 128).astype('float32')
    imt = torch.tensor(image, dtype=torch.float32, requires_grad=True,
                       device=dev)

    wavelet = pywt.Wavelet(wave)
    analysis = (wavelet.dec_lo, wavelet.dec_hi)
    reversed_analysis = (wavelet.dec_lo[::-1], wavelet.dec_hi[::-1])
    dwt = DWTForward(J=J, wave=analysis, mode=mode).to(dev)
    iwt = DWTInverse(wave=reversed_analysis, mode=mode).to(dev)

    yl, yh = dwt(imt)

    # Lowpass output
    upstream_lo = torch.randn(*yl.shape, device=dev)
    yl.backward(upstream_lo, retain_graph=True)
    zeros = [torch.zeros_like(yh[level]) for level in range(J)]
    expected = iwt((upstream_lo, zeros))
    np.testing.assert_array_almost_equal(
        imt.grad.detach().cpu(), expected.cpu(), decimal=PREC_FLT)

    # Highpass outputs, one level at a time
    for level, band in enumerate(yh):
        imt.grad.zero_()  # gradients accumulate, so clear the previous run
        upstream = torch.randn(*band.shape, device=dev)
        band.backward(upstream, retain_graph=True)
        highpasses = list(zeros)
        highpasses[level] = upstream
        expected = iwt((torch.zeros_like(yl), highpasses))
        np.testing.assert_array_almost_equal(
            imt.grad.detach().cpu(), expected.cpu(), decimal=PREC_FLT)
示例#5
0
def test_gradients_inv(wave, J, mode):
    """Backward pass of the inverse DWT must match the forward DWT when the
    inverse uses the reversed decomposition filters.

    Random coefficients are reconstructed and a random upstream gradient is
    back-propagated. The gradients that reach the coefficients are compared
    with the forward transform of that upstream gradient.
    """
    wavelet = pywt.Wavelet(wave)
    analysis = (wavelet.dec_lo, wavelet.dec_hi)
    reversed_analysis = (wavelet.dec_lo[::-1], wavelet.dec_hi[::-1])
    dwt = DWTForward(J=J, wave=analysis, mode=mode).to(dev)
    iwt = DWTInverse(wave=reversed_analysis, mode=mode).to(dev)

    # A dummy forward pass gives the shapes of the coefficient pyramid.
    probe = torch.zeros(5, 6, 128, 128).to(dev)
    lo_template, hi_templates = dwt(probe)

    # Random coefficients that track gradients.
    yl = torch.randn(*lo_template.shape, requires_grad=True, device=dev)
    yh = [
        torch.randn(*hi_templates[level].shape, requires_grad=True,
                    device=dev)
        for level in range(J)
    ]
    y = iwt((yl, yh))

    # Back-propagate a random upstream gradient and build the reference.
    upstream = torch.randn(*y.shape, device=dev)
    y.backward(upstream, retain_graph=True)
    expected_lo, expected_hi = dwt(upstream)

    # Lowpass
    np.testing.assert_array_almost_equal(
        yl.grad.detach().cpu(), expected_lo.cpu(), decimal=PREC_FLT)

    # Highpass, level by level
    for level in range(J):
        np.testing.assert_array_almost_equal(
            yh[level].grad.detach().cpu(),
            expected_hi[level].cpu(),
            decimal=PREC_FLT)
示例#6
0
    def __init__(self,
                 C,
                 F,
                 lp_size=3,
                 bp_sizes=(1, ),
                 q=1.0,
                 wave='db2',
                 mode='zero',
                 xfm=True,
                 ifm=True):
        """Set up the optional DWT, shrinkage, gain layer and optional IDWT.

        Args:
            C: stored as ``self.C`` and passed to ``WaveGainLayer``.
            F: stored as ``self.F`` and passed to ``WaveGainLayer``.
            lp_size: passed to ``WaveGainLayer``.
            bp_sizes: passed to ``WaveGainLayer``; its length is the number
                of wavelet levels ``J``.
            q: ``ReLUWaveCoeffs`` is used for shrinkage only when ``q < 0``;
                otherwise the shrinkage is the identity.
            wave: wavelet for the forward/inverse transforms.
            mode: padding mode for the forward/inverse transforms.
            xfm: if false, ``self.XFM`` is the identity function.
            ifm: if false, ``self.IFM`` is the identity function.
        """
        super().__init__()
        self.C, self.F = C, F
        self.J = len(bp_sizes)

        self.XFM = (DWTForward(J=self.J, mode=mode, wave=wave)
                    if xfm else (lambda x: x))
        self.shrink = ReLUWaveCoeffs() if q < 0 else (lambda x: x)
        self.GainLayer = WaveGainLayer(C, F, lp_size, bp_sizes)
        self.IFM = (DWTInverse(mode=mode, wave=wave)
                    if ifm else (lambda x: x))
示例#7
0
def test_equal_double(wave, J, mode):
    """Float64 round trip and PyWavelets comparison on a 64x64 input.

    The input and both transforms are created inside
    ``set_double_precision()``; the test asserts the input is float64. The
    reconstruction must match the input, and the coefficients must match
    ``pywt.wavedec2``.
    """
    with set_double_precision():
        x = torch.randn(5, 4, 64, 64).to(dev)
        assert x.dtype == torch.float64
        dwt = DWTForward(J=J, wave=wave, mode=mode).to(dev)
        iwt = DWTInverse(wave=wave, mode=mode).to(dev)

    yl, yh = dwt(x)
    x2 = iwt((yl, yh))
    x_host = x.cpu()

    # Perfect reconstruction
    np.testing.assert_array_almost_equal(
        x_host, x2.detach().cpu(), decimal=PREC_DBL)

    # Agreement with PyWavelets
    coeffs = pywt.wavedec2(
        x_host.numpy(), wave, level=J, axes=(-2, -1), mode=mode)
    np.testing.assert_array_almost_equal(yl.cpu(), coeffs[0], decimal=7)
    for level in range(J):
        # yh[level] is compared with coeffs[J - level]: opposite ordering.
        reference = coeffs[J - level]
        ours = yh[level]
        for band in range(3):
            np.testing.assert_array_almost_equal(
                reference[band], ours[:, :, band].cpu(), decimal=PREC_DBL)
示例#8
0
 def __init__(self, J=3, wave='db4', device='cpu', dtype=torch.float64):
     """Create periodised forward/inverse DWT modules under a given dtype.

     The process-wide default dtype is switched to ``dtype`` only while the
     two transforms are constructed, then switched back.

     Args:
         J: number of decomposition levels for the forward transform.
         wave: wavelet passed to both transforms.
         device: device the transforms are moved to.
         dtype: default dtype in effect while the transforms are built.
     """
     super().__init__()
     _dtype = torch.get_default_dtype()
     torch.set_default_dtype(dtype)
     try:
         self.dwt = DWTForward(J=J, wave=wave, mode='per').to(device)
         self.idwt = DWTInverse(wave=wave, mode='per').to(device)
     finally:
         # Restore the global default even if construction fails, so an
         # exception here cannot leave the whole process on `dtype`.
         torch.set_default_dtype(_dtype)
     self.norm_bound = 1.
示例#9
0
def test_ok(wave, J, mode):
    """Every tensor produced by the forward and inverse DWT is contiguous."""
    x = torch.randn(5, 4, 64, 64).to(dev)
    forward = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    inverse = DWTInverse(wave=wave, mode=mode).to(dev)

    yl, yh = forward(x)
    x2 = inverse((yl, yh))

    # Can have data errors sometimes
    assert yl.is_contiguous()
    for level in range(J):
        highpass = yh[level]
        assert highpass.is_contiguous()
    assert x2.is_contiguous()
示例#10
0
 def __init__(self, C, F, k=4, stride=1, J=1, wd=0, wd1=None, right=True):
     """Initialise wavelet-domain parameters from a Xavier-uniform kernel.

     A ``(F, C, k, k)`` tensor is Xavier-initialised and passed through a
     ``J``-level ``DWTForward``. The lowpass output and the highpass output
     at index ``J - 1`` become the parameters ``gl`` and ``gh``.

     Args:
         C, F: stored on the instance; they also size the initial kernel.
         k: kernel size; only 4 and 8 are supported.
         stride: accepted but not used in this constructor.
         J: number of levels; ``k == 4`` supports ``J == 1``, ``k == 8``
             supports ``J`` in 1..3.
         wd: stored as ``self.wd``.
         wd1: stored as ``self.wd1``; falls back to ``wd`` when ``None``.
         right: selects which of the two asymmetric paddings is stored.

     Raises:
         NotImplementedError: for any other ``(k, J)`` combination.
     """
     super().__init__()
     self.wd = wd
     self.wd1 = wd if wd1 is None else wd1
     self.C = C
     self.F = F

     kernel = torch.zeros(F, C, k, k)
     torch.nn.init.xavier_uniform_(kernel)
     xfm = DWTForward(J=J)
     self.ifm = DWTInverse()
     yl, yh = xfm(kernel)
     self.J = J

     # (pad when right is true, pad when right is false) per kernel size
     pads_by_k = {
         4: ((1, 2, 1, 2), (2, 1, 2, 1)),
         8: ((3, 4, 3, 4), (4, 3, 4, 3)),
     }
     supported = ((4, 1), (8, 1), (8, 2), (8, 3))
     if not any(k == kk and J == jj for kk, jj in supported):
         raise NotImplementedError

     deepest = yh[J - 1]
     self.gl = nn.Parameter(torch.zeros_like(yl))
     self.gh = nn.Parameter(torch.zeros_like(deepest))
     self.gl.data = yl.data
     self.gh.data = deepest.data

     pad_right, pad_left = pads_by_k[4 if k == 4 else 8]
     self.pad = pad_right if right else pad_left
示例#11
0
    def __init__(
        self,
        num_levels: int = 1,
        wave: str = 'db2',
        mode: str = 'periodization',
        device: Optional[str] = None,
    ) -> None:
        """Store the settings and build the forward/inverse DWT pair.

        Args:
            num_levels: number of decomposition levels (``J`` of the
                forward transform).
            wave: wavelet shared by both transforms.
            mode: padding mode shared by both transforms.
            device: passed to ``.to(...)`` on both transforms.
        """
        self.num_levels = num_levels
        self.wave = wave
        self.mode = mode

        shared = {'wave': wave, 'mode': mode}
        forward = DWTForward(J=num_levels, **shared)
        self.xfm = forward.to(device)
        inverse = DWTInverse(**shared)
        self.ifm = inverse.to(device)
示例#12
0
 def __init__(self, in_ch, internal_ch=None, filter_sz=3, num_conv=3):
     """Build the Haar inverse DWT and a stack of 3x3 conv + ReLU layers.

     Args:
         in_ch: channel count of every conv except the last one's output.
         internal_ch: output channels of the final conv; defaults to
             ``in_ch``.
         filter_sz: accepted but not used here; every conv is 3x3 with
             padding 1.
         num_conv: number of ``in_ch -> in_ch`` conv + ReLU pairs placed
             before the final ``in_ch -> internal_ch`` conv + ReLU.
     """
     super(IWCNN, self).__init__()
     out_ch = in_ch if internal_ch is None else internal_ch
     self.IDwT = DWTInverse(wave='haar', mode='zero')

     # No normalisation layers: just conv followed by ReLU.
     layers = []
     for _ in range(num_conv):
         layers += [nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1),
                    nn.ReLU()]
     layers += [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU()]
     self.conv = nn.Sequential(*layers)
示例#13
0
 def __init__(self,
              discriminator_size,
              perceptual_size=256,
              generator_weights=None):
     """Create the individual loss terms and the db1 wavelet transforms.

     Args:
         discriminator_size: converted to ``int`` and passed to ``GANLoss``
             as ``image_size``.
         perceptual_size: passed to ``PerceptualLoss``.
         generator_weights: list stored as ``self.generator_weights``.
             When ``None`` (the default), a fresh
             ``[1.0, 1.0, 1.0, 0.5, 0.1]`` is created for this instance.
     """
     super().__init__()
     # A list literal as the default would be one object shared by every
     # instance, so mutating one instance's weights would change them all.
     if generator_weights is None:
         generator_weights = [1.0, 1.0, 1.0, 0.5, 0.1]
     # l1/l2 loss
     self.l1_loss = nn.L1Loss()
     self.l2_loss = nn.MSELoss()
     # perceptual loss
     self.perceptual_loss = PerceptualLoss(perceptual_size)
     # gan loss
     self.gan_loss = GANLoss(image_size=int(discriminator_size))
     self.generator_weights = generator_weights
     # single-level db1 forward transform and matching inverse
     self.dwt = DWTForward(J=1, mode='zero', wave='db1')
     self.idwt = DWTInverse(mode="zero", wave="db1")
示例#14
0
def test_equal(wave, J, mode):
    """Round trip and PyWavelets comparison on a random 64x64 input.

    The inverse of the forward transform must reproduce the input, and the
    coefficients must match ``pywt.wavedec2`` called with the same wavelet,
    level count and mode.
    """
    x = torch.randn(5, 4, 64, 64).to(dev)
    dwt = DWTForward(J=J, wave=wave, mode=mode).to(dev)
    iwt = DWTInverse(wave=wave, mode=mode).to(dev)
    yl, yh = dwt(x)
    x2 = iwt((yl, yh))

    # Test the forward and inverse worked. The reconstruction is moved to
    # host memory first; NumPy cannot read a tensor that lives on CUDA.
    np.testing.assert_array_almost_equal(
        x.cpu(), x2.detach().cpu(), decimal=PREC_FLT)

    # Test it is the same as doing the PyWavelets wavedec with the same
    # padding mode
    coeffs = pywt.wavedec2(x.cpu().numpy(), wave, level=J, axes=(-2, -1),
                           mode=mode)
    np.testing.assert_array_almost_equal(yl.cpu(), coeffs[0], decimal=PREC_FLT)
    for j in range(J):
        for b in range(3):
            np.testing.assert_array_almost_equal(
                coeffs[J-j][b], yh[j][:, :, b].cpu(), decimal=PREC_FLT)
示例#15
0
 def __init__(self,
              wt_type='DTCWT',
              biort='near_sym_b',
              qshift='qshift_b',
              J=5,
              wave='db3',
              mode='zero',
              device='cuda',
              requires_grad=True):
     """Select and build a forward/inverse wavelet transform pair.

     Args:
         wt_type: ``'DTCWT'`` or ``'DWT'``.
         biort, qshift: filter names used by the DTCWT pair.
         J: number of levels of the forward transform.
         wave, mode: wavelet and padding mode used by the DWT pair.
         device: device both transforms are moved to.
         requires_grad: accepted but not used in this constructor.

     Raises:
         ValueError: if ``wt_type`` is neither ``'DTCWT'`` nor ``'DWT'``.
     """
     super().__init__()
     if wt_type == 'DTCWT':
         dtcwt_filters = dict(biort=biort, qshift=qshift)
         self.xfm = DTCWTForward(J=J, **dtcwt_filters).to(device)
         self.ifm = DTCWTInverse(**dtcwt_filters).to(device)
     elif wt_type == 'DWT':
         dwt_settings = dict(wave=wave, mode=mode)
         self.xfm = DWTForward(J=J, **dwt_settings).to(device)
         self.ifm = DWTInverse(**dwt_settings).to(device)
     else:
         raise ValueError('no such type of wavelet transform is supported')
     self.J, self.wt_type = J, wt_type
示例#16
0
def init_dwt(resume=None, shape=None, wave=None, colors=None):
    """Build the CUDA DWT pair and the initial wavelet coefficients.

    Args:
        resume: where the coefficients come from.
            ``None``: random coefficients shaped like the transform of a
            zero tensor of ``shape``.
            ``str``: path to an image (jpg/png/tif/bmp), decomposed with
            ``img2dwt``, or to any other file, loaded with ``torch.load``.
            Anything else: an iterable of tensors, each moved to the GPU.
        shape: full tensor shape; ``shape[2:]`` is the spatial size used to
            choose the number of levels. It is indexed on every call.
        wave: wavelet passed to ``DWTForward`` / ``DWTInverse`` and
            ``img2dwt``.
        colors: passed to ``img2dwt`` when resuming from an image.

    Returns:
        Tuple ``(Ys, xfm, ifm, size)``. ``Ys`` holds the coefficients, with
        the lowpass first. ``size`` is the first two image dimensions when
        resuming from an image, otherwise ``None``.

    Raises:
        SystemExit: with status 1 if ``resume`` is a string that is not an
            existing file.
    """
    image_exts = ('jpg', 'png', 'tif', 'bmp')

    def _forward_for(spatial_shape):
        # PyWavelets is used only to find the deepest level available for
        # this spatial size; the transform itself is the torch one.
        probe = pywt.WaveletPacket2D(data=np.zeros(spatial_shape),
                                     wavelet='db1',
                                     mode='symmetric')
        return DWTForward(J=probe.maxlevel, wave=wave,
                          mode='symmetric').cuda()

    size = None
    xfm = _forward_for(shape[2:])
    ifm = DWTInverse(wave=wave, mode='symmetric').cuda()

    if resume is None:  # random init
        Yl_in, Yh_in = xfm(torch.zeros(shape).cuda())
        Ys = [torch.randn(*Y.shape).cuda() for Y in [Yl_in, *Yh_in]]
    elif isinstance(resume, str):
        if not os.path.isfile(resume):
            print(' Snapshot not found:', resume)
            # The `exit()` helper comes from the `site` module, may be
            # missing, and exits with status 0. Signal failure explicitly.
            raise SystemExit(1)
        if os.path.splitext(resume)[1].lower()[1:] in image_exts:
            img_in = imread(resume)
            Ys = img2dwt(img_in, wave=wave, colors=colors)
            print(' loaded image', resume, img_in.shape, 'level',
                  len(Ys) - 1)
            size = img_in.shape[:2]
            # The level count now follows the image, not `shape`.
            xfm = _forward_for(size)
        else:
            # NOTE(security): torch.load unpickles the file; only resume
            # from snapshots that come from a trusted source.
            Ys = torch.load(resume)
            Ys = [y.detach().cuda() for y in Ys]
    else:
        Ys = [y.cuda() for y in resume]
    return Ys, xfm, ifm, size
示例#17
0
        netG.apply(weights_init)
else:
    netG.apply(weights_init)

# Set up the generator optimizer selected by opt.optimizer ('adam' or 'sgd').
# NOTE(review): there is no else branch, so any other value leaves
# `optimizerG` undefined. Confirm that opt.optimizer is validated upstream.
if opt.optimizer == 'adam':
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
elif opt.optimizer == 'sgd':
    optimizerG = optim.SGD(netG.parameters(), lr=opt.lr, momentum=0.9)

# Cross-entropy criterion, moved to the first GPU in gpulist.
criterion_pre = nn.CrossEntropyLoss()
criterion_pre = criterion_pre.cuda(gpulist[0])

# Create the forward (3-level) and inverse db3 wavelet transforms with
# symmetric padding, both on the first GPU in gpulist.
xfm = DWTForward(J=3, wave='db3', mode='symmetric').cuda(gpulist[0])
ifm = DWTInverse(wave='db3', mode='symmetric').cuda(gpulist[0])

# Penalty coefficients for the conditional loss; larger keys map to smaller
# coefficients. What the float keys represent is not visible in this
# fragment (presumably a perturbation budget; confirm against the loss).
dssimCoef = {
   0.05: 3.0,
   0.1: 2.0,
   0.2: 1.0,
   0.3: 0.5
}

#________________________________________________________________________________
#
## @brief conditional loss function with budget control
#
#________________________________________________________________________________
def conditionalBudgetLoss(perturbedX, x, outputLabel, targetLabel):
示例#18
0
 def __init__(self, opt):
     """Create every sub-module of the DSWN generator.

     Builds the single-level Haar forward/inverse DWT (moved to CUDA here),
     encoder convs E1-E4, a bottleneck of four ``ResConv2dLayer`` blocks,
     the ``_DCR_block`` stages, decoder convs D1-D5 and the 1x1 "channel
     shuffle" convs S1-S4.

     Args:
         opt: options object; this constructor reads ``opt.pad`` (padding
             type) and ``opt.norm`` (normalisation type).
     """
     super(DSWN, self).__init__()
     # Both transforms are placed on the GPU at construction time.
     self.DWT = DWTForward(J=1, wave='haar').cuda()
     self.IDWT = DWTInverse(wave='haar').cuda()
     # The generator is U shaped
     # Encoder
     # Input channels grow as 3, 3*4, 3*4*4, 3*4*16 from E1 to E4.
     self.E1 = Conv2dLayer(in_channels=3,
                           out_channels=160,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='prelu',
                           norm=opt.norm)
     self.E2 = Conv2dLayer(in_channels=3 * 4,
                           out_channels=256,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='prelu',
                           norm=opt.norm)
     self.E3 = Conv2dLayer(in_channels=3 * 4 * 4,
                           out_channels=256,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='prelu',
                           norm=opt.norm)
     self.E4 = Conv2dLayer(in_channels=3 * 4 * 16,
                           out_channels=256,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='prelu',
                           norm=opt.norm)
     # Bottle neck
     self.BottleNeck = nn.Sequential(
         ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm),
         ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm),
         ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm),
         ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm))
     # _DCR_block stages built through self.make_layer. Only the first
     # stage has four blocks (11-14); the others keep one active block each.
     self.blockDMT11 = self.make_layer(_DCR_block, 320)
     self.blockDMT12 = self.make_layer(_DCR_block, 320)
     self.blockDMT13 = self.make_layer(_DCR_block, 320)
     self.blockDMT14 = self.make_layer(_DCR_block, 320)
     self.blockDMT21 = self.make_layer(_DCR_block, 512)
     # self.blockDMT22 = self.make_layer(_DCR_block, 512)
     # self.blockDMT23 = self.make_layer(_DCR_block, 512)
     # self.blockDMT24 = self.make_layer(_DCR_block, 512)
     self.blockDMT31 = self.make_layer(_DCR_block, 512)
     # self.blockDMT32 = self.make_layer(_DCR_block, 512)
     # self.blockDMT33 = self.make_layer(_DCR_block, 512)
     # self.blockDMT34 = self.make_layer(_DCR_block, 512)
     self.blockDMT41 = self.make_layer(_DCR_block, 256)
     # self.blockDMT42 = self.make_layer(_DCR_block, 256)
     # self.blockDMT43 = self.make_layer(_DCR_block, 256)
     # self.blockDMT44 = self.make_layer(_DCR_block, 256)
     # self.DRB11 = ResidualDenseBlock_5C(nf=320, gc=64)
     # self.DRB12 = ResidualDenseBlock_5C(nf=320, gc=64)
     # self.DRB21 = ResidualDenseBlock_5C(nf=512, gc=64)
     # self.DRB31 = ResidualDenseBlock_5C(nf=512, gc=64)
     # self.DRB41 = ResidualDenseBlock_5C(nf=256, gc=64)
     # Decoder
     self.D1 = Conv2dLayer(in_channels=256,
                           out_channels=1024,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='prelu',
                           norm=opt.norm)
     self.D2 = Conv2dLayer(in_channels=512,
                           out_channels=1024,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='prelu',
                           norm=opt.norm)
     self.D3 = Conv2dLayer(in_channels=512,
                           out_channels=640,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='prelu',
                           norm=opt.norm)
     # D4 and D5 are identically configured 320 -> 3 output convs with tanh
     # activation and no normalisation.
     self.D4 = Conv2dLayer(in_channels=320,
                           out_channels=3,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           norm='none',
                           activation='tanh')
     self.D5 = Conv2dLayer(in_channels=320,
                           out_channels=3,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           norm='none',
                           activation='tanh')
     # channel shuffle
     # 1x1 convs with no activation and no normalisation.
     self.S1 = Conv2dLayer(in_channels=512,
                           out_channels=512,
                           kernel_size=1,
                           stride=1,
                           padding=0,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='none',
                           norm='none')
     self.S2 = Conv2dLayer(in_channels=512,
                           out_channels=512,
                           kernel_size=1,
                           stride=1,
                           padding=0,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='none',
                           norm='none')
     self.S3 = Conv2dLayer(in_channels=320,
                           out_channels=320,
                           kernel_size=1,
                           stride=1,
                           padding=0,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='none',
                           norm='none')
     # NOTE(review): groups=3 * 320 (960) is larger than the 320 in/out
     # channels. If Conv2dLayer forwards `groups` to nn.Conv2d unchanged,
     # that combination would be rejected. Confirm how it is handled.
     self.S4 = Conv2dLayer(in_channels=320,
                           out_channels=320,
                           kernel_size=1,
                           stride=1,
                           padding=0,
                           dilation=1,
                           groups=3 * 320,
                           pad_type=opt.pad,
                           activation='none',
                           norm='none')
示例#19
0
 def __init__(self, opt):
     """Create every sub-module of the MWCNN generator.

     Builds the single-level Haar forward/inverse DWT (moved to CUDA here),
     encoder convs E1-E4, a bottleneck of four ``ResConv2dLayer`` blocks,
     four DMT block stages and decoder convs D1-D4.

     Args:
         opt: options object; this constructor reads ``opt.pad`` (padding
             type) and ``opt.norm`` (normalisation type).
     """
     super(MWCNN, self).__init__()
     # Both transforms are placed on the GPU at construction time.
     self.DWT = DWTForward(J=1, wave='haar').cuda()
     self.IDWT = DWTInverse(wave='haar').cuda()
     # The generator is U shaped
     # Encoder
     # Each encoder conv takes 4x the previous conv's output channels
     # (3*4, 640 = 4*160, 1024 = 4*256). Presumably the four subbands of a
     # DWT are stacked along channels; confirm in forward().
     self.E1 = Conv2dLayer(in_channels=3 * 4,
                           out_channels=160,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='relu',
                           norm=opt.norm)
     self.E2 = Conv2dLayer(in_channels=640,
                           out_channels=256,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='relu',
                           norm=opt.norm)
     self.E3 = Conv2dLayer(in_channels=1024,
                           out_channels=256,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='relu',
                           norm=opt.norm)
     self.E4 = Conv2dLayer(in_channels=1024,
                           out_channels=256,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='relu',
                           norm=opt.norm)
     # Bottle neck
     self.BottleNeck = nn.Sequential(
         ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm),
         ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm),
         ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm),
         ResConv2dLayer(256, 3, 1, 1, pad_type=opt.pad, norm=opt.norm))
     # One stage per DMT block type, each built via self.make_layer with 3
     # as the second argument (presumably the block count; confirm).
     self.blockDMT1 = self.make_layer(Block_of_DMT1, 3)
     self.blockDMT2 = self.make_layer(Block_of_DMT2, 3)
     self.blockDMT3 = self.make_layer(Block_of_DMT3, 3)
     self.blockDMT4 = self.make_layer(Block_of_DMT4, 3)
     # Decoder
     self.D1 = Conv2dLayer(in_channels=256,
                           out_channels=1024,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='relu',
                           norm=opt.norm)
     self.D2 = Conv2dLayer(in_channels=256,
                           out_channels=1024,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='relu',
                           norm=opt.norm)
     self.D3 = Conv2dLayer(in_channels=256,
                           out_channels=640,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           activation='relu',
                           norm=opt.norm)
     # Final conv: 160 -> 3 * 4 channels, tanh activation, no normalisation.
     self.D4 = Conv2dLayer(in_channels=160,
                           out_channels=3 * 4,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           dilation=1,
                           pad_type=opt.pad,
                           norm='none',
                           activation='tanh')
示例#20
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_features,
                 num_simdb,
                 upscale_factor,
                 act_type='prelu',
                 norm_type=None):
        """Build the feature extractor, recurrent block, reconstruction
        heads and the frozen inverse wavelet transform.

        Args:
            in_channels: channels of the input passed to ``conv_in``.
            out_channels: channels of the first head; every later head
                outputs ``3 * out_channels``.
            num_features: feature width used throughout.
            num_simdb: passed to ``RecurrentBlock``.
            upscale_factor: stored on the instance; sets
                ``num_steps = int(log2(upscale_factor) + 1)``.
            act_type: activation type passed to the conv blocks.
            norm_type: normalisation type passed to the conv blocks.
        """
        super(WSR, self).__init__()

        padding = 2
        self.num_features = num_features
        self.upscale_factor = upscale_factor
        self.num_steps = int(np.log2(self.upscale_factor) + 1)

        # Options shared by every activated block below.
        common = dict(act_type=act_type, norm_type=norm_type)

        # LR feature extraction block
        self.conv_in = ConvBlock(in_channels, 4 * num_features,
                                 kernel_size=3, **common)
        self.feat_in = ConvBlock(4 * num_features, num_features,
                                 kernel_size=1, **common)

        # recurrent block
        self.rb = RecurrentBlock(num_features, num_simdb, act_type, norm_type)

        def _projection(channels):
            # Final 3x3 conv of a head: no activation.
            return ConvBlock(num_features, channels, kernel_size=3,
                             act_type=None, norm_type=norm_type)

        def _plain_head(channels):
            # 3x3 conv with activation followed by the projection conv.
            return nn.Sequential(
                ConvBlock(num_features, num_features, kernel_size=3,
                          **common),
                _projection(channels))

        # reconstruction block: two plain heads, then one deconvolution
        # head for each remaining step.
        heads = [_plain_head(out_channels), _plain_head(out_channels * 3)]
        for step in range(2, self.num_steps):
            stride = 2 ** (step - 1)
            heads.append(nn.Sequential(
                DeconvBlock(num_features, num_features,
                            kernel_size=stride + 4,
                            stride=stride,
                            padding=padding,
                            **common),
                _projection(out_channels * 3)))
        self.conv_steps = nn.ModuleList(heads)

        # inverse wavelet transformation, in eval mode with frozen parameters
        self.ifm = DWTInverse(wave='db1', mode='symmetric').eval()
        for param in self.ifm.parameters():
            param.requires_grad = False
示例#21
0
class WSR(nn.Module):
    """Super-resolution network that predicts wavelet coefficients.

    A feature extractor feeds a recurrent block; one reconstruction head
    produces the low-pass band and the remaining heads produce high-pass
    bands, which are combined by a frozen inverse DWT ('db1', symmetric).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_features,
                 num_simdb,
                 upscale_factor,
                 act_type='prelu',
                 norm_type=None):
        super(WSR, self).__init__()

        self.num_features = num_features
        self.upscale_factor = upscale_factor
        # One head for the low-pass band plus one per log2(upscale_factor).
        self.num_steps = int(np.log2(self.upscale_factor) + 1)
        deconv_padding = 2

        def conv3(n_out, activation):
            # 3x3 ConvBlock reading ``num_features`` channels.
            return ConvBlock(num_features,
                             n_out,
                             kernel_size=3,
                             act_type=activation,
                             norm_type=norm_type)

        # LR feature extraction block
        wide = 4 * num_features
        self.conv_in = ConvBlock(in_channels,
                                 wide,
                                 kernel_size=3,
                                 act_type=act_type,
                                 norm_type=norm_type)
        self.feat_in = ConvBlock(wide,
                                 num_features,
                                 kernel_size=1,
                                 act_type=act_type,
                                 norm_type=norm_type)

        # recurrent block
        self.rb = RecurrentBlock(num_features, num_simdb, act_type, norm_type)

        # reconstruction block: head 0 emits ``out_channels`` channels, every
        # other head emits ``3 * out_channels`` channels.
        heads = [
            nn.Sequential(conv3(num_features, act_type),
                          conv3(out_channels, None)),
            nn.Sequential(conv3(num_features, act_type),
                          conv3(out_channels * 3, None)),
        ]
        for step in range(2, self.num_steps):
            factor = 2**(step - 1)
            upsample = DeconvBlock(num_features,
                                   num_features,
                                   kernel_size=int(factor + 4),
                                   stride=int(factor),
                                   padding=deconv_padding,
                                   act_type=act_type,
                                   norm_type=norm_type)
            heads.append(nn.Sequential(upsample, conv3(out_channels * 3, None)))
        self.conv_steps = nn.ModuleList(heads)

        # inverse wavelet transformation, kept in eval mode with its
        # parameters excluded from gradient updates
        self.ifm = DWTInverse(wave='db1', mode='symmetric').eval()
        for param in self.ifm.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the inverse DWT of the predicted coefficients for ``x``."""
        self._reset_state()

        feats = self.feat_in(self.conv_in(x))

        Yl = self.conv_steps[0](feats)

        Yh = []
        for step in range(1, self.num_steps):
            h = self.conv_steps[step](self.rb(feats))
            dims = h.size()
            # Split the channel axis into (channels // 3, 3).
            h = h.view(dims[0], dims[1] // 3, 3, dims[2], dims[3])
            # Later steps are placed in front of earlier ones.
            Yh.insert(0, h)

        return self.ifm((Yl, Yh))

    def _reset_state(self):
        """Reset the recurrent block before a new forward pass."""
        self.rb.reset_state()
示例#22
0
# Forward DWT of the full-resolution image: a low-pass tensor plus a list of
# high-pass tensors.
Yl, Yh = DWTForward(J=J_spec, mode='periodization', wave='db3')(img)
print(Yl.shape, [band.shape for band in Yh])

# Same transform, one level shallower, on a half-scale copy of the image.
imgLR = F.interpolate(img, scale_factor=.5)
LQYl, LQYh = DWTForward(J=J_spec-1, mode='periodization', wave='db3')(imgLR)
print(LQYl.shape, [band.shape for band in LQYh])

# Dump every high-pass level (summed over dim 2) and the low-pass band.
for i in range(J_spec):
    smd = Yh[i].sum(dim=2).cpu()
    save_img(smd, "high_%i.png" % (i,))
save_img(Yl, "lo.png")

'''
Following code reconstructs the image with different high passes cancelled out.
'''
inverse_dwt = DWTInverse(mode='periodization', wave='db3')
for i in range(J_spec):
    # Copy the list so only level ``i`` is zeroed for this reconstruction.
    corrupted_im = list(Yh)
    corrupted_im[i] = torch.zeros_like(Yh[i])
    im = inverse_dwt((Yl, corrupted_im))
    save_img(im, "corrupt_%i.png" % (i,))
# Finally keep all high passes but flatten the low pass to its mean value.
flat_lowpass = torch.full_like(Yl, fill_value=torch.mean(Yl))
im = inverse_dwt((flat_lowpass, Yh))
save_img(im, "corrupt_im.png")


'''
Following code reconstructs a hybrid image with the first high pass from the HR and the rest of the data from the LR.
highpass = [Yh[0]] + LQYh
im = DWTInverse(mode='periodization', wave='db3')((LQYl, highpass))
save_img(im, "hybrid_lrhr.png")
save_img(F.interpolate(imgLR, scale_factor=2), "upscaled.png")
'''
示例#23
0
def perturb_img_bands_only(img,
                           label,
                           target,
                           model,
                           attack_type,
                           dataset=None,
                           net_type=None,
                           alpha=0,
                           max_iters=20,
                           epsilon=8.0 / 255.0,
                           random_restarts=False,
                           update_type='gradient',
                           step_size=2.0 / 255.0,
                           bands='ll',
                           lr=5e-2,
                           odi=False,
                           loss_type='xent',
                           class2idx=None):
    """Craft an adversarial image by perturbing selected Haar wavelet bands.

    Each iteration takes a one-level Haar DWT of the current adversarial
    image, back-propagates the loss to the wavelet coefficients, moves the
    selected bands by ``lr * sign(grad)``, reconstructs the image and clips
    it to an ``epsilon`` box around the clean image and to the valid pixel
    range.

    Args:
        img: Clean input batch; passed through ``postprocess`` before use.
        label: Not used by this implementation.
        target: Labels fed to the loss. Images whose prediction differs from
            ``target`` after a restart replace the stored result.
        model: Network under attack.
        attack_type: Not used by this implementation.
        dataset: Dataset name. ``'imagenet'`` remaps ``target`` through
            ``class2idx``; ``'multi-pie'`` with a vggface2 ``net_type``
            switches to a 0..255 pixel range with ``epsilon=8.0``, ``lr=2.0``.
        net_type: Network name forwarded to ``preprocess``/``postprocess``.
        alpha: Not used by this implementation.
        max_iters: Number of update iterations per restart.
        epsilon: Maximum absolute per-pixel deviation from the clean image.
        random_restarts: If true, run 20 restarts, each starting from uniform
            noise in ``[-epsilon, epsilon]``; otherwise run once.
        update_type: Only printed; the update is always a signed-gradient step.
        step_size: Not used by this implementation.
        bands: Which bands to perturb: ``'ll'``, ``'lh'``, ``'hl'``, ``'hh'``,
            ``'high'`` (LH, HL, HH) or ``'all'``.
        lr: Step size applied to the sign of the band gradients.
        odi: If true, two extra iterations are run per restart.
        loss_type: ``'xent'`` (cross entropy) or ``'margin'``.
        class2idx: Mapping used to build the index-to-class table; required
            when ``dataset == 'imagenet'``.

    Returns:
        ``preprocess(best_adv, dataset, net_type)`` for the retained
        adversarial batch.

    Raises:
        ValueError: For an unknown ``bands`` or ``loss_type``, or when
            ``dataset == 'imagenet'`` and ``class2idx`` is missing.
    """
    band_groups = {
        'll': ('LL',),
        'lh': ('LH',),
        'hl': ('HL',),
        'hh': ('HH',),
        'high': ('LH', 'HL', 'HH'),
        'all': ('LL', 'LH', 'HL', 'HH'),
    }
    # Fail early with a clear message instead of a NameError deep in the loop.
    if bands not in band_groups:
        raise ValueError('bands must be one of {}, got {!r}'.format(
            sorted(band_groups), bands))
    if loss_type not in ('xent', 'margin'):
        raise ValueError(
            "loss_type must be 'xent' or 'margin', got {!r}".format(loss_type))
    if dataset == 'imagenet' and class2idx is None:
        raise ValueError("class2idx is required when dataset == 'imagenet'")

    print('Using bands only:', bands)
    print('Loss fn:', loss_type)
    print('Update:', update_type)

    if dataset == 'imagenet':
        # Invert class2idx so that str(index) maps to int(class key).
        idx2class = {str(class2idx[key]): int(key) for key in class2idx}
        for i in range(len(target)):
            target[i] = idx2class[str(target[i].data.cpu().numpy())]

        target = target.cuda()

    print(target)

    num_restarts = 20 if random_restarts else 1
    odi_steps = 2 if odi else 0

    # The multi-pie / vggface2 variants use 0..255 pixels and a larger budget.
    byte_range = (dataset == 'multi-pie'
                  and net_type in ('vggface2', 'vggface2-kernel-def'))
    if byte_range:
        if net_type == 'vggface2-kernel-def':
            print('adjusted epsilon')
        epsilon = 8.0
        lr = 2.0
    pixel_max = 255.0 if byte_range else 1.0

    loss_fn = torch.nn.CrossEntropyLoss() if loss_type == 'xent' \
        else margin_loss

    ifm = DWTInverse(mode='zero', wave='haar').cuda()
    xfm = DWTForward(J=1, mode='zero', wave='haar').cuda()

    def reconstruct(coeffs):
        # Inverse DWT from the LL band and the stacked (LH, HL, HH) bands.
        highs = torch.stack((coeffs['LH'], coeffs['HL'], coeffs['HH']), 2)
        return ifm((coeffs['LL'], [highs]))

    # The clean reference does not change across restarts or iterations.
    # NOTE(review): assumes postprocess is deterministic — confirm.
    clean = postprocess(img.clone().detach(), dataset, net_type).data

    best_adv = None
    for restart in range(num_restarts):
        print('Restart {}'.format(restart + 1))

        adv = Variable(clean.clone(), requires_grad=True)
        if random_restarts:
            random_noise = torch.FloatTensor(*adv.shape).uniform_(
                -epsilon, epsilon).cuda()
            adv = Variable(adv.data + random_noise, requires_grad=True)

        for _ in range(odi_steps + max_iters):
            LL, Y = xfm(adv)
            LH, HL, HH = torch.unbind(Y[0], dim=2)
            # Detach each band so the gradient is taken w.r.t. the
            # coefficients themselves.
            coeffs = {
                name: Variable(band.data, requires_grad=True)
                for name, band in (('LL', LL), ('LH', LH), ('HL', HL),
                                   ('HH', HH))
            }

            band_loss = loss_fn(
                model(preprocess(reconstruct(coeffs), dataset, net_type)),
                target)
            band_loss.backward()

            # Signed-gradient step on the selected bands only.
            for name in band_groups[bands]:
                band = coeffs[name]
                band_eta = lr * band.grad.data.sign()
                coeffs[name] = Variable(band.data + band_eta,
                                        requires_grad=True)

            # Project back into the epsilon box and the valid pixel range.
            eta = torch.clamp(
                reconstruct(coeffs).data - clean, -epsilon, epsilon)
            adv = Variable(torch.clamp(clean + eta, 0, pixel_max),
                           requires_grad=True)

        with torch.no_grad():
            curr_pred = torch.max(model(preprocess(adv, dataset, net_type)),
                                  1)[1]
            if best_adv is None:
                best_adv = adv.clone().detach()
            else:
                # Keep this restart's image wherever the prediction now
                # differs from the target.
                for i in range(len(curr_pred)):
                    if curr_pred[i].data != target[i].data:
                        best_adv[i] = adv[i]

    return preprocess(best_adv, dataset, net_type)
示例#24
0
    def __init__(self):  # NOTE: pay attention to the upscales
        """Create the wavelet transforms and the DMT/IDMT layers."""
        super(MWCNN, self).__init__()

        def conv3x3(n_in, n_out):
            # Convolution with kernel 3, stride 1 and padding 1.
            return nn.Conv2d(in_channels=n_in,
                             out_channels=n_out,
                             kernel_size=3,
                             stride=1,
                             padding=1)

        # Single-level Haar forward / inverse wavelet transforms.
        self.DWT = DWTForward(J=1, wave="haar").cuda()
        self.IDWT = DWTInverse(wave="haar").cuda()

        # ---- DMT1 ----
        # The input image goes through one DWT first, which multiplies the
        # channel count by 4 (hence 3 * 4 input channels); the
        # feature-map-preserving blocks keep the channel count consistent.
        self.conv_DMT1 = conv3x3(3 * 4, 160)
        self.bn_DMT1 = nn.BatchNorm2d(160, affine=True)
        self.relu_DMT1 = nn.ReLU()
        # IDMT1: inverse transform
        self.conv_IDMT1 = conv3x3(160, 3 * 4)
        # feature-map-preserving blocks
        self.blockDMT1 = self.make_layer(Block_of_DMT1, 3)

        # ---- DMT2 ----
        self.conv_DMT2 = conv3x3(640, 256)
        self.bn_DMT2 = nn.BatchNorm2d(256, affine=True)
        self.relu_DMT2 = nn.ReLU()
        # IDMT2
        self.conv_IDMT2 = conv3x3(256, 640)
        self.bn_IDMT2 = nn.BatchNorm2d(640, affine=True)
        self.relu_IDMT2 = nn.ReLU()

        self.blockDMT2 = self.make_layer(Block_of_DMT2, 3)

        # ---- DMT3 ----
        self.conv_DMT3 = conv3x3(1024, 256)
        self.bn_DMT3 = nn.BatchNorm2d(256, affine=True)
        self.relu_DMT3 = nn.ReLU()
        # IDMT3
        self.conv_IDMT3 = conv3x3(256, 1024)
        self.bn_IDMT3 = nn.BatchNorm2d(1024, affine=True)
        self.relu_IDMT3 = nn.ReLU()

        self.blockDMT3 = self.make_layer(Block_of_DMT3, 3)

        # Extra convolutions reducing 160 -> 80 -> 12 channels.
        self.try_conv1 = conv3x3(160, 80)
        self.try_conv2 = conv3x3(80, 12)
示例#25
0
def mix_bands(img_1, img_2, net, label_1, label_2, dataset, switch_band):
    """Swap one-level Haar wavelet bands between two images and report results.

    Both images are decomposed with a single-level Haar DWT, the bands named
    by ``switch_band`` are exchanged, the images are reconstructed and clipped
    to a +/-0.031 box around the originals, a 2x10 grid of bands and images is
    saved to ``./imgs/mix.png``, and the predictions of ``net`` are printed.

    Parameters:
        img_1: First RGB image. NOTE(review): a batch axis is added below with
            ``unsqueeze(0)``, so this is presumably 3xHxW — confirm.
        img_2: Second RGB image, same layout as ``img_1``.
        net: Classifier whose arg-max predictions are printed.
        label_1: Label printed for image 1.
        label_2: Label printed for image 2.
        dataset: Dataset name forwarded to ``postprocess`` / ``preprocess``.
        switch_band: One of ``'ll'``, ``'lh'``, ``'hl'``, ``'hh'``, ``'high'``.
            Any other value leaves both images unmixed (there is no ``else``).

    No return statement is visible in the reviewed portion of this function.
    """

    # Work on batched copies so the caller's tensors are not modified.
    img_1 = torch.unsqueeze(img_1.clone(), 0)
    img_2 = torch.unsqueeze(img_2.clone(), 0)

    img_1 = postprocess(img_1, dataset)
    img_2 = postprocess(img_2, dataset)

    # Keep the unmixed images as the reference for the epsilon clipping below.
    img_1_og = img_1.clone()
    img_2_og = img_2.clone()

    # Single-level Haar transforms with zero padding, placed on the GPU.
    ifm = DWTInverse(mode='zero', wave='haar').cuda()
    xfm = DWTForward(J=1, mode='zero', wave='haar').cuda()

    # Row i = image i; columns 0-4 = original bands/image, 5-9 = after mixing.
    fig, ax = plt.subplots(nrows=2, ncols=10, figsize=(20, 10))

    # Plot the bands of the original images (columns 0-4).
    for i, img in enumerate([img_1, img_2]):
        LL, Y = xfm(img)
        LH, HL, HH = torch.unbind(Y[0], dim=2)

        # High-pass bands are divided by their maximum and scaled by 10 for
        # display. NOTE(review): no guard against a zero maximum.
        ax[i, 0].imshow(LL[0].data.cpu().permute(1, 2, 0))
        ax[i, 1].imshow(10 * LH[0].data.cpu().permute(1, 2, 0) /
                        torch.max(LH[0].data.cpu().permute(1, 2, 0)))
        ax[i, 2].imshow(10 * HL[0].data.cpu().permute(1, 2, 0) /
                        torch.max(HL[0].data.cpu().permute(1, 2, 0)))
        ax[i, 3].imshow(10 * HH[0].data.cpu().permute(1, 2, 0) /
                        torch.max(HH[0].data.cpu().permute(1, 2, 0)))
        ax[i, 4].imshow(img[0].data.cpu().permute(1, 2, 0))

        ax[i, 0].set_title('LL' + '_' + str(i))
        ax[i, 1].set_title('LH' + '_' + str(i))
        ax[i, 2].set_title('HL' + '_' + str(i))
        ax[i, 3].set_title('HH' + '_' + str(i))
        ax[i, 4].set_title('normal_img' + '_' + str(i))

        # Hide the axis ticks on every panel.
        ax[i, 0].set_yticks([], [])
        ax[i, 0].set_xticks([], [])
        ax[i, 1].set_yticks([], [])
        ax[i, 1].set_xticks([], [])
        ax[i, 2].set_yticks([], [])
        ax[i, 2].set_xticks([], [])
        ax[i, 3].set_yticks([], [])
        ax[i, 3].set_xticks([], [])
        ax[i, 4].set_yticks([], [])
        ax[i, 4].set_xticks([], [])

    # Reconstruct using new components
    LL_1, Y_1 = xfm(img_1)
    LH_1, HL_1, HH_1 = torch.unbind(Y_1[0], dim=2)

    LL_2, Y_2 = xfm(img_2)
    LH_2, HL_2, HH_2 = torch.unbind(Y_2[0], dim=2)

    # Exchange the requested band(s) between the two images.
    if switch_band == 'll':
        img_1 = ifm((LL_2, [torch.stack((LH_1, HL_1, HH_1), 2)]))
        img_2 = ifm((LL_1, [torch.stack((LH_2, HL_2, HH_2), 2)]))

    elif switch_band == 'lh':
        img_1 = ifm((LL_1, [torch.stack((LH_2, HL_1, HH_1), 2)]))
        img_2 = ifm((LL_2, [torch.stack((LH_1, HL_2, HH_2), 2)]))

    elif switch_band == 'hl':
        img_1 = ifm((LL_1, [torch.stack((LH_1, HL_2, HH_1), 2)]))
        img_2 = ifm((LL_2, [torch.stack((LH_2, HL_1, HH_2), 2)]))

    elif switch_band == 'hh':
        img_1 = ifm((LL_1, [torch.stack((LH_1, HL_1, HH_2), 2)]))
        img_2 = ifm((LL_2, [torch.stack((LH_2, HL_2, HH_1), 2)]))

    elif switch_band == 'high':
        img_1 = ifm((LL_1, [torch.stack((LH_2, HL_2, HH_2), 2)]))
        img_2 = ifm((LL_2, [torch.stack((LH_1, HL_1, HH_1), 2)]))

    # Limit the change to +/-epsilon per pixel around the original image,
    # then clip to the [0, 1] range.
    epsilon = 0.031
    eta = torch.clamp(img_1.data - img_1_og, -epsilon, epsilon)
    img_1 = img_1_og + eta
    img_1 = torch.clamp(img_1, 0, 1.0)

    eta = torch.clamp(img_2.data - img_2_og, -epsilon, epsilon)
    img_2 = img_2_og + eta
    img_2 = torch.clamp(img_2, 0, 1.0)

    # Plot the bands of the mixed images (columns 5-9).
    for i, img in enumerate([img_1, img_2]):
        LL, Y = xfm(img)
        LH, HL, HH = torch.unbind(Y[0], dim=2)

        ax[i, 0 + 5].imshow(LL[0].data.cpu().permute(1, 2, 0))
        ax[i, 1 + 5].imshow(10 * LH[0].data.cpu().permute(1, 2, 0) /
                            torch.max(LH[0].data.cpu().permute(1, 2, 0)))
        ax[i, 2 + 5].imshow(10 * HL[0].data.cpu().permute(1, 2, 0) /
                            torch.max(HL[0].data.cpu().permute(1, 2, 0)))
        ax[i, 3 + 5].imshow(10 * HH[0].data.cpu().permute(1, 2, 0) /
                            torch.max(HH[0].data.cpu().permute(1, 2, 0)))
        ax[i, 4 + 5].imshow(img[0].data.cpu().permute(1, 2, 0))

        ax[i, 0 + 5].set_title('LL' + '_' + str(i))
        ax[i, 1 + 5].set_title('LH' + '_' + str(i))
        ax[i, 2 + 5].set_title('HL' + '_' + str(i))
        ax[i, 3 + 5].set_title('HH' + '_' + str(i))
        ax[i, 4 + 5].set_title('perturbed_img' + '_' + str(i))

        ax[i, 0 + 5].set_yticks([], [])
        ax[i, 0 + 5].set_xticks([], [])
        ax[i, 1 + 5].set_yticks([], [])
        ax[i, 1 + 5].set_xticks([], [])
        ax[i, 2 + 5].set_yticks([], [])
        ax[i, 2 + 5].set_xticks([], [])
        ax[i, 3 + 5].set_yticks([], [])
        ax[i, 3 + 5].set_xticks([], [])
        ax[i, 4 + 5].set_yticks([], [])
        ax[i, 4 + 5].set_xticks([], [])

    plt.savefig('./imgs/mix.png')

    # NOTE(review): this prints a list of 20 '*' strings, not a 20-character
    # line of asterisks — confirm which was intended.
    print(['*'] * 20)

    # NOTE(review): `resize` is not used in the visible part of this function.
    resize = transforms.Resize((32, 32))

    # NOTE(review): the bands of Y_1 (from img_1) are bound to the *_2 names
    # and the bands of Y_2 (from img_2) to the *_1 names, so the "LH_1" label
    # printed below is computed from img_2's bands and vice versa. Confirm
    # whether this cross-naming is intentional.
    LL_1, Y_1 = xfm(preprocess(img_1, dataset))
    LH_2, HL_2, HH_2 = torch.unbind(Y_1[0], dim=2)

    LL_2, Y_2 = xfm(preprocess(img_2, dataset))
    LH_1, HL_1, HH_1 = torch.unbind(Y_2[0], dim=2)

    # Report the network's arg-max prediction for each mixed image and for
    # each of its high-pass bands interpolated to size 32.
    print('image 1 label:\t', label_1)
    print('perturbed label:\t',
          torch.max(net(preprocess(img_1, dataset)), 1)[1])
    print('print LH_1 label:\t', torch.max(net(F.interpolate(LH_1, 32)), 1)[1])
    print('print HL_1 label:\t', torch.max(net(F.interpolate(HL_1, 32)), 1)[1])
    print('print HH_1 label:\t', torch.max(net(F.interpolate(HH_1, 32)), 1)[1])
    print('image 2 label:\t', label_2)
    print('perturbed label:\t',
          torch.max(net(preprocess(img_2, dataset)), 1)[1])
    print('print LH_2 label:\t', torch.max(net(F.interpolate(LH_2, 32)), 1)[1])
    print('print HL_2 label:\t', torch.max(net(F.interpolate(HL_2, 32)), 1)[1])
    print('print HH_2 label:\t', torch.max(net(F.interpolate(HH_2, 32)), 1)[1])