コード例 #1
0
def test():
    """Visual smoke test for ``apply_mask`` on ``test.jpg`` next to this file.

    Shows the original image (scaled to [-1, 1]). Then, for each mask strength
    in ``[1., 0.5, 0.1]``, shows five results of ``apply_mask`` with a uniform
    mask filled with that strength. Finally back-propagates the sum of the last
    masked image and prints the sum of the last mask's gradient.

    Every print call takes a single argument so the output is identical under
    Python 2 and Python 3.
    """
    from PIL import Image
    import numpy as np
    import os
    import pycat

    path = os.path.join(os.path.dirname(__file__), 'test.jpg')
    # The context manager closes the file; np.array reads the pixel data
    # while the file is still open.
    with Image.open(path) as img:
        pixels = np.array(img)
    # (H, W, C) -> (1, C, H, W), values rescaled from [0, 255] to [-1, 1].
    im = Variable(torch.Tensor(
        np.expand_dims(np.transpose(pixels, (2, 0, 1)), 0) / 255. * 2 - 1.),
        requires_grad=False)
    print('Original')
    pycat.show(im[0].data.numpy())
    print('')
    for pres in [1., 0.5, 0.1]:
        print('Mask strength = %s' % pres)
        for _ in range(5):
            m = Variable(torch.Tensor(1, 3, im.size(2),
                                      im.size(3)).fill_(pres),
                         requires_grad=True)
            res = apply_mask(im, m)
            pycat.show(res[0].data.numpy())
    # Gradient check uses the last (mask, result) pair produced above.
    s = torch.sum(res)
    s.backward()
    print(torch.sum(m.grad))
コード例 #2
0
def test(cuda=True):
    """Demo and speed benchmark for the pretrained saliency function.

    Loads ``sal/utils/test2.jpg`` (relative to this file). It then shows the
    image masked with the saliency map for class 340 (zebra) and for class 386
    (elephant). Finally it times repeated calls on random batches.

    Args:
        cuda: forwarded to ``get_pretrained_saliency_fn`` and
            ``load_image_as_variable``. It also selects how many benchmark
            rounds run (20 with CUDA, 2 without).

    All print calls take a single argument, so output is the same under
    Python 2 and Python 3.
    """
    f = get_pretrained_saliency_fn(cuda=cuda)
    import os
    import time

    import pycat
    from sal.utils.mask import apply_mask

    # simply load an image
    ims = load_image_as_variable(
        os.path.join(os.path.dirname(__file__), 'sal/utils/test2.jpg'),
        cuda=cuda)

    zebra_mask = f(ims, 340)  # 340 is a zebra
    # 386 is an elephant (check sal/datasets/imagenet_synset.py for more)
    elephant_mask = f(ims, [386])

    print('You should see a zebra')
    pycat.show(apply_mask(ims, zebra_mask, boolean=False).cpu()[0].data.numpy() * 128 + 128,
               auto_normalize=False)
    print('You should see an elefant')
    pycat.show(apply_mask(ims, elephant_mask, boolean=False).cpu()[0].data.numpy() * 128 + 128,
               auto_normalize=False)

    # One source of truth for the benchmark size, used by both the loop and
    # the reported rate.
    rounds = 20 if cuda else 2
    batch_size = 32

    print('Testing speed with CUDA_ENABLED = %s' % cuda)
    print('Please wait...')
    t = time.time()
    for _ in range(rounds):
        # Random inputs and random class ids in [0, 100). ``np.int`` was only
        # an alias for the builtin ``int`` and was removed in NumPy 1.24.
        # The trailing 6 is passed through to f unchanged (presumably an
        # iteration count - TODO confirm).
        f(np.random.randn(batch_size, 3, 224, 224),
          np.random.uniform(0, 100, size=(batch_size,)).astype(int), 6)
    print('Images per second: %s' % (float(batch_size) * rounds / (time.time() - t)))
    print('You should expect ~200 images per second on a GPU (Titan XP) and 2.5 images per second on a CPU. ')
コード例 #3
0
def test():
    """Smoke-test the validation loader: batch statistics and throughput.

    Iterates over at most ``MAX_BATCHES`` batches of ``BS`` images. Every
    ``SAMP`` batches it does three things:
    - prints min/max/var/mean of the current batch;
    - prints the images-per-second rate measured since the previous report;
    - shows the first image of the batch.
    """
    BS = 64
    SAMP = 20
    MAX_BATCHES = 100
    dts = get_val_dataset()
    loader = get_loader(dts, batch_size=BS)
    t = time.time()
    for i, (ims, _labs) in enumerate(loader, 1):
        # The rate below assumes exactly SAMP batches elapsed since t was last
        # reset, so the reporting interval must be SAMP as well (it used to be
        # a separate hard-coded 20).
        if i % SAMP == 0:
            print(
                "min",
                torch.min(ims),
                "max",
                torch.max(ims),
                "var",
                torch.var(ims),
                "mean",
                torch.mean(ims),
            )
            print("Images per second:", SAMP * BS / (time.time() - t))
            pycat.show(ims[0].numpy())
            t = time.time()
        if i == MAX_BATCHES:
            break
コード例 #4
0
 def f(s):
     """Show one stored image, or several side by side, framed by blank lines.

     ``imgs_names`` and ``ith`` are free variables (enclosing or module scope,
     not visible here).

     If ``imgs_names`` is a single string, ``s.pt_store[imgs_names][ith]`` is
     shown as-is. Otherwise each named entry is passed through ``auto_norm``
     and the results are concatenated along axis 1.
     """
     if not isinstance(imgs_names, str):
         panels = [auto_norm(s.pt_store[name][ith]) for name in imgs_names]
         im = np.concatenate(panels, 1)
     else:
         im = s.pt_store[imgs_names][ith]
     print()
     pycat.show(im)
     print()
コード例 #5
0
def phase2_visualise(s):
    """Print the first sample's target label and show its visual summary.

    Reads the first entry of ``images``, ``masks``, ``preserved``,
    ``destroyed``, ``is_real_label`` and ``targets`` from ``s.pt_store``.

    Prints whether the target is flagged real (``GREEN_STR``) or fake
    (``RED_STR``), together with its class name from
    ``dts.CLASS_ID_TO_NAME``. Then shows original | mask | preserved |
    destroyed concatenated along axis 1.

    Print calls take a single argument, so the output is the same under
    Python 2 and Python 3 (the print statements were a SyntaxError on Py3).
    """
    pt = s.pt_store
    orig = auto_norm(pt['images'][0])
    # The mask is scaled by 255 and displayed without auto-normalisation.
    mask = auto_norm(pt['masks'][0] * 255, auto_normalize=False)
    preserved = auto_norm(pt['preserved'][0])
    destroyed = auto_norm(pt['destroyed'][0])
    print('')
    print('Target (%s) = %s' % ((GREEN_STR % 'REAL') if pt['is_real_label'][0]
                                else (RED_STR % 'FAKE!'),
                                dts.CLASS_ID_TO_NAME[pt['targets'][0]]))
    final = np.concatenate((orig, mask, preserved, destroyed), axis=1)
    pycat.show(final)
コード例 #6
0
    def get_saliency_maps(self, _images, _targets, iterations=None, show=False):
        """Optimise a per-image mask with SGD and return the resulting saliency maps.

        Params
        _images - input images of shape (C, H, W) or (N, C, H, W) if in batch.
            Can be a numpy array, a Tensor or a Variable.
        _targets - class ids to be masked: an int or an array with N integers.
            Can be a numpy array, a Tensor or a Variable.
        iterations - number of optimisation steps; ``self.default_iterations``
            is used when None.
        show - when True, display ``PT["masks"][0] * 255`` and
            ``PT["preserved"][0]`` via pycat after every step.

        Returns ``PT.masks`` as populated by the last
        ``saliency_loss_calc.get_loss`` call. Per the original docstring, this
        is a Variable of shape (N, 1, H, W) with one saliency map per image.
        """
        images = to_batch_variable(_images, 4, self.cuda).float()
        targets = to_batch_variable(_targets, 1, self.cuda).long()

        n_steps = self.default_iterations if iterations is None else iterations

        # Two raw channels per mask cell, both initialised to 0.5. The mask fed
        # to the loss is |a| / (|a| + |b| + 0.001), which always lies in [0, 1).
        init = torch.Tensor(
            images.size(0), 2, self.mask_resolution, self.mask_resolution
        ).fill_(0.5)
        if self.cuda:
            init = init.cuda()
        raw_mask = nn.Parameter(init)
        optimizer = torch_optim.SGD([raw_mask], 0.1, 0.9, nesterov=True)

        for _ in range(n_steps):
            optimizer.zero_grad()

            a = torch.abs(raw_mask[:, 0, :, :])
            b = torch.abs(raw_mask[:, 1, :, :])
            mask = torch.unsqueeze(a / (a + b + 0.001), dim=1)

            loss = self.saliency_loss_calc.get_loss(
                images, targets, mask, pt_store=PT
            )
            loss.backward()
            optimizer.step()

            if show:
                pycat.show(PT["masks"][0] * 255, auto_normalize=False)
                pycat.show(PT["preserved"][0])
        return PT.masks
コード例 #7
0
def phase2_visualise(s):
    """Print the first sample's target label and show its visual summary.

    Reads the first entry of ``images``, ``masks``, ``preserved``,
    ``destroyed``, ``is_real_label`` and ``targets`` from ``s.pt_store``.

    Prints whether the target is flagged real (``GREEN_STR``) or fake
    (``RED_STR``), together with its class name from
    ``dts.CLASS_ID_TO_NAME``. Then shows original | mask | preserved |
    destroyed concatenated along axis 1.
    """
    store = s.pt_store
    panels = (
        auto_norm(store["images"][0]),
        # The mask is scaled by 255 and displayed without auto-normalisation.
        auto_norm(store["masks"][0] * 255, auto_normalize=False),
        auto_norm(store["preserved"][0]),
        auto_norm(store["destroyed"][0]),
    )
    print()
    if store["is_real_label"][0]:
        status = GREEN_STR % "REAL"
    else:
        status = RED_STR % "FAKE!"
    label = dts.CLASS_ID_TO_NAME[store["targets"][0]]
    print("Target (%s) = %s" % (status, label))
    pycat.show(np.concatenate(panels, axis=1))
コード例 #8
0
def test():
    """Throughput smoke test for the training loader.

    Iterates over at most ``MAX_BATCHES`` batches of ``BS`` images. Every
    ``SAMP`` batches it prints the images-per-second rate measured since the
    previous report and shows the first image of the current batch.

    The print call takes a single argument, so the output is the same under
    Python 2 and Python 3 (the print statement was a SyntaxError on Py3).
    """
    BS = 64
    SAMP = 20
    MAX_BATCHES = 100
    dts = get_train_dataset()
    loader = get_loader(dts, batch_size=BS)
    t = time.time()
    for i, (ims, _labs) in enumerate(loader, 1):
        # The rate assumes exactly SAMP batches elapsed since t was last reset,
        # so report every SAMP batches (previously a separate hard-coded 20).
        if i % SAMP == 0:
            print("Images per second: %s" % (SAMP * BS / (time.time() - t)))
            pycat.show(ims[0].numpy())
            t = time.time()
        if i == MAX_BATCHES:
            break
コード例 #9
0
def test():
    """Visual and numeric check of ``gaussian_blur`` on ``test.jpg`` next to this file.

    Shows the original and the blurred image, then back-propagates the sum of
    the blurred image. The resulting input gradient must have a mean above 0.9
    and be symmetric under vertical (axis 1) and horizontal (axis 2) flips.

    Raises:
        AssertionError: if any of those gradient checks fails.
    """
    from PIL import Image
    import os
    import pycat

    path = os.path.join(os.path.dirname(__file__), "test.jpg")
    # The context manager closes the file; np.array reads the pixel data
    # while the file is still open.
    with Image.open(path) as img:
        pixels = np.array(img)
    # (H, W, C) -> (1, C, H, W), values rescaled to [0, 1]. requires_grad so
    # that im.grad is populated by backward().
    im = Variable(
        torch.Tensor(np.expand_dims(np.transpose(pixels, (2, 0, 1)), 0) / 255.0),
        requires_grad=True,
    )
    g = gaussian_blur(im)
    print("Original")
    pycat.show(im[0].data.numpy())
    print("Blurred version")
    pycat.show(g[0].data.numpy())
    print(
        "Image gradient over blurred sum (should be white in the middle + turning darker at the edges)"
    )
    total = torch.sum(g)
    total.backward()
    gr = im.grad[0].data.numpy()
    # The former checks np.mean(np.flip(gr, k) - gr) < 1e-6 could never fail:
    # flipping an array does not change its mean, so the difference's mean is
    # always ~0. Compare elementwise instead.
    assert np.mean(gr) > 0.9, "mean gradient too small: %s" % np.mean(gr)
    assert np.allclose(np.flip(gr, 1), gr, atol=1e-5), "gradient not symmetric along axis 1"
    assert np.allclose(np.flip(gr, 2), gr, atol=1e-5), "gradient not symmetric along axis 2"
    pycat.show(gr)