def set_rcam(self):
        """Precompute per-class CAM channel weights over the training set.

        Registers a forward hook on the model's last conv block that records
        the spatially max-pooled activation of each image, then accumulates
        those activations per class into ``self.weights`` (an
        ``EvaluationMetrics`` keyed by class index). Only ResNet-style models
        are supported; for anything else this is a no-op.
        """
        print("Creating CAM for {}".format(self.args.model))
        if 'resnet' in str.lower(type(self.model).__name__):
            last_conv = 'layer4'
        else:
            print("Model not implemented. Setting rcam=False by default.")
            return

        # Note: instance attribute `self.weights` is distinct from the local
        # `weights` list used below to ferry activations out of the hook.
        self.weights = EvaluationMetrics(list(range(self.args.num_classes)))
        def hook_weights(module, input, output):
            # `weights` resolves to the enclosing-scope list that the loop
            # below rebinds to [] before every forward pass.
            weights.append(F.adaptive_max_pool2d(output, (1,1)))
        handle = self.model._modules.get(last_conv).register_forward_hook(hook_weights)

        # batch_size=1 so each forward produces exactly one pooled activation.
        train_loader, _ = get_loader(self.args.dataset,
            batch_size=1,
            num_workers=self.args.workers
        )
        for i, (image, label) in enumerate(train_loader):
            weights = []
            _ = self.model(to_var(image, volatile=True))
            weights = weights[0].squeeze()
            label = label.squeeze()[0]
            self.weights.update(label, weights)
            if (i+1)%1000 == 0:
                print("{:5.1f}% ({}/{})".format((i+1)/len(train_loader)*100, i+1, len(train_loader)))
        # Remove the hook so later forwards stop appending activations.
        handle.remove()
    def eval(self):
        """Run one validation pass, then checkpoint when accuracy improves.

        Tracks running loss / top-1 accuracy / wall time per batch and emits
        an 'EVAL' scalar summary at the end.
        """
        self.model.eval()
        metrics = EvaluationMetrics(['Loss', 'Acc', 'Time'])

        for images, labels in self.val_loader:
            tic = time.time()
            images = torch.autograd.Variable(images)
            labels = torch.autograd.Variable(labels)
            if self.args.cuda:
                images = images.cuda()
                labels = labels.cuda()
            if self.args.half:
                images = images.half()

            outputs, loss = self.compute_loss(images, labels)

            # Cast back to fp32 so metrics are unaffected by half precision.
            outputs = outputs.float()
            loss = loss.float()
            elapsed = time.time() - tic

            preds = torch.max(outputs, 1)[1]
            acc = (labels == preds.squeeze()).float().mean()

            n = labels.size(0)
            metrics.update('Loss', float(loss), n)
            metrics.update('Acc', float(acc), n)
            metrics.update('Time', elapsed, n)

        # Persist a checkpoint whenever validation accuracy hits a new best.
        if metrics.avg['Acc'] > self.best_acc:
            self.save()
            self.logger.log("Saving best model: epoch={}".format(self.epoch))
            self.best_acc = metrics.avg['Acc']
            self.maybe_delete_old_pth(log_path=self.log_path.as_posix(), max_to_keep=1)

        self.logger.scalar_summary(metrics.avg, self.step, 'EVAL')
    def train_epoch(self):
        """Adversarial training for one epoch.

        For each batch, generates adversarial examples via ``self.attacker``
        and minimizes a convex combination of the clean and adversarial
        losses: ``alpha * loss_clean + (1 - alpha) * loss_adv``. Accuracy is
        measured over the concatenation of clean and adversarial samples.
        Logs step metrics every ``log_step`` steps and parameter/gradient
        histograms once per epoch.
        """
        self.model.train()
        eval_metrics = EvaluationMetrics(['Loss', 'Acc', 'Time'])

        for i, (images, labels) in enumerate(self.train_loader):
            st = time.time()
            self.step += 1

            # Generate adversarial examples from raw tensors *before* the
            # Variable wrapping below.
            if self.args.cuda:
                images = images.cuda()
                labels = labels.cuda()
            adv_images, _ = self.attacker.generate(images, labels)
            if self.args.cuda: adv_images = adv_images.cuda()

            images = torch.autograd.Variable(images)
            adv_images = torch.autograd.Variable(adv_images)
            labels = torch.autograd.Variable(labels)
            if self.args.half:
                images = images.half()
                adv_images = adv_images.half()

            outputs_clean, loss_clean = self.compute_loss(images, labels)
            outputs_adv, loss_adv = self.compute_loss(adv_images, labels)
            # Concatenate clean + adversarial halves so the accuracy below
            # covers both; batch_size is therefore 2x the loader batch.
            images = torch.cat([images, adv_images], dim=0)
            labels = torch.cat([labels, labels])

            outputs = torch.cat([outputs_clean, outputs_adv], dim=0)
            loss = self.alpha * loss_clean + (1 - self.alpha) * loss_adv

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Back to fp32 so metrics are unaffected by half precision.
            outputs = outputs.float()
            loss = loss.float()
            elapsed_time = time.time() - st

            _, preds = torch.max(outputs, 1)
            accuracy = (labels == preds.squeeze()).float().mean()

            batch_size = labels.size(0)
            eval_metrics.update('Loss', float(loss.data[0]), batch_size)
            eval_metrics.update('Acc', float(accuracy.data[0]), batch_size)
            eval_metrics.update('Time', elapsed_time, batch_size)

            if self.step % self.args.log_step == 0:
                self.logger.scalar_summary(eval_metrics.val, self.step, 'STEP')

        # Histogram of parameters
        for tag, p in self.model.named_parameters():
            tag = tag.split(".")
            tag = "Train_{}".format(tag[0]) + "/" + "/".join(tag[1:])
            try:
                self.logger.writer.add_histogram(tag, p.clone().cpu().data.numpy(), self.step)
                self.logger.writer.add_histogram(tag+'/grad', p.grad.clone().cpu().data.numpy(), self.step)
            except Exception as e:
                # Parameters that never received a gradient land here.
                print("Check if variable {} is not used: {}".format(tag, e))

        self.logger.scalar_summary(eval_metrics.avg, self.step, 'EPOCH')
    def eval(self):
        """Run one validation pass (loss only) and checkpoint on best loss.

        Labels are ignored; ``compute_loss`` is called on images alone.
        Emits an 'EVAL' scalar summary at the end.
        """
        self.model.eval()
        metrics = EvaluationMetrics(['Loss', 'Time'])

        for images, _ in self.val_loader:
            tic = time.time()
            images = torch.autograd.Variable(images)
            if self.args.cuda:
                images = images.cuda()
            if self.args.half:
                images = images.half()

            outputs, loss = self.compute_loss(images)

            loss = loss.float()
            elapsed = time.time() - tic

            n = images.size(0)
            metrics.update('Loss', float(loss.data[0]), n)
            metrics.update('Time', elapsed, n)

        # Persist a checkpoint whenever validation loss hits a new best (low).
        if metrics.avg['Loss'] < self.best_loss:
            self.save()
            self.logger.log("Saving best model: epoch={}".format(self.epoch))
            self.best_loss = metrics.avg['Loss']
            self.maybe_delete_old_pth(log_path=self.log_path.as_posix(), max_to_keep=1)

        self.logger.scalar_summary(metrics.avg, self.step, 'EVAL')
    def train_epoch(self):
        """Optimize the model for one epoch (unsupervised-style loss).

        Labels are ignored; ``compute_loss`` is called on images alone.
        Logs step metrics every ``log_step`` steps and parameter/gradient
        histograms once per epoch.
        """
        self.model.train()
        metrics = EvaluationMetrics(['Loss', 'Time'])

        for images, _ in self.train_loader:
            tic = time.time()
            self.step += 1
            images = torch.autograd.Variable(images)
            if self.args.cuda:
                images = images.cuda()
            if self.args.half:
                images = images.half()

            outputs, loss = self.compute_loss(images)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            loss = loss.float()
            elapsed = time.time() - tic

            n = images.size(0)
            metrics.update('Loss', float(loss.data[0]), n)
            metrics.update('Time', elapsed, n)

            if self.step % self.args.log_step == 0:
                self.logger.scalar_summary(metrics.val, self.step, 'STEP')

        # Parameter/gradient histograms, tagged "Train_<top>/<sub>/...".
        for name, param in self.model.named_parameters():
            parts = name.split(".")
            tag = "Train_{}".format(parts[0]) + "/" + "/".join(parts[1:])
            try:
                self.logger.writer.add_histogram(tag, param.clone().cpu().data.numpy(), self.step)
                self.logger.writer.add_histogram(tag+'/grad', param.grad.clone().cpu().data.numpy(), self.step)
            except Exception as e:
                # Parameters that never received a gradient land here.
                print("Check if variable {} is not used: {}".format(tag, e))

        self.logger.scalar_summary(metrics.avg, self.step, 'EPOCH')
    def defend(self):
        """Evaluate an attack/defense pair over the validation set.

        Per batch: (1) clean inference, (2) adversarial generation,
        (3) Lp statistics of the perturbation, (4) defense applied to the
        adversarial images, and (5) defense applied to the *clean* images to
        measure the defense's own distortion. Periodically logs scalar
        summaries, example images, and a "defense rate" — the fraction of
        attack-induced accuracy drop recovered by the defense.
        """
        self.model.eval()
        defense_scheme = getattr(defenses,
                                 self.args.defense)(self.model, self.args,
                                                    **self.kwargs)
        source = self.model
        # Transfer-attack setting: craft adversarial examples on a different
        # source model. NOTE(review): this mutates self.args (model/ckpt_name)
        # for the rest of the run — confirm nothing downstream re-reads them.
        if self.args.source is not None and (self.args.ckpt_name !=
                                             self.args.ckpt_src):
            target = self.args.ckpt_name
            self.args.model = self.args.source
            self.args.ckpt_name = self.args.ckpt_src
            source = get_model(self.args)
            self.logger.log("Transfer attack from {} -> {}".format(
                self.args.ckpt_src, target))
        attack_scheme = getattr(attacks, self.args.attack)(source, self.args,
                                                           **self.kwargs)

        eval_metrics = EvaluationMetrics(
            ['Test/Acc', 'Test/Top5', 'Test/Time'])
        eval_def_metrics = EvaluationMetrics(
            ['Def-Test/Acc', 'Def-Test/Top5', 'Def-Test/Time'])
        attack_metrics = EvaluationMetrics(
            ['Attack/Acc', 'Attack/Top5', 'Attack/Time'])
        defense_metrics = EvaluationMetrics(
            ['Defense/Acc', 'Defense/Top5', 'Defense/Time'])
        dist_metrics = EvaluationMetrics(['L0', 'L1', 'L2', 'Li'])

        for i, (images, labels) in enumerate(self.val_loader):
            self.step += 1
            # NOTE(review): uses self.cuda here but self.args.cuda elsewhere
            # in this file — confirm both attributes exist on this class.
            if self.cuda:
                images = images.cuda()
                labels = labels.cuda()
            if self.args.half: images = images.half()

            # Inference
            st = time.time()
            outputs = self.model(self.to_var(images, self.cuda, True))
            outputs = outputs.float()
            _, preds = torch.topk(outputs, 5)

            acc = (labels == preds.data[:, 0]).float().mean()
            top5 = torch.sum(
                (labels.unsqueeze(1).repeat(1, 5) == preds.data).float(),
                dim=1).mean()
            eval_metrics.update('Test/Acc', float(acc), labels.size(0))
            eval_metrics.update('Test/Top5', float(top5), labels.size(0))
            eval_metrics.update('Test/Time', time.time() - st, labels.size(0))

            # Attacker
            st = time.time()
            adv_images, adv_labels = attack_scheme.generate(images, labels)
            if isinstance(adv_images, Variable):
                adv_images = adv_images.data
            attack_metrics.update('Attack/Time',
                                  time.time() - st, labels.size(0))

            # Lp distance between original and adversarial, in image scale.
            diff = torch.abs(
                denormalize(adv_images, self.args.dataset) -
                denormalize(images, self.args.dataset))
            # L0 counts spatial positions whose channel-summed change > 1e-3.
            L0 = torch.sum((torch.sum(diff, dim=1) > 1e-3).float().view(
                labels.size(0), -1),
                           dim=1).mean()
            diff = diff.view(labels.size(0), -1)
            L1 = torch.norm(diff, p=1, dim=1).mean()
            L2 = torch.norm(diff, p=2, dim=1).mean()
            Li = torch.max(diff, dim=1)[0].mean()
            dist_metrics.update('L0', float(L0), labels.size(0))
            dist_metrics.update('L1', float(L1), labels.size(0))
            dist_metrics.update('L2', float(L2), labels.size(0))
            dist_metrics.update('Li', float(Li), labels.size(0))

            # Defender
            st = time.time()
            def_images, def_labels = defense_scheme.generate(
                adv_images, adv_labels)
            if isinstance(
                    def_images, Variable
            ):  # FIXME - Variable in Variable out for all methods
                def_images = def_images.data
            defense_metrics.update('Defense/Time',
                                   time.time() - st, labels.size(0))
            self.calc_stats('Attack', adv_images, images, adv_labels, labels,
                            attack_metrics)
            self.calc_stats('Defense', def_images, images, def_labels, labels,
                            defense_metrics)

            # Defense-Inference for shift of original image
            st = time.time()
            def_images_org, _ = defense_scheme.generate(images, labels)
            if isinstance(
                    def_images_org, Variable
            ):  # FIXME - Variable in Variable out for all methods
                def_images_org = def_images_org.data
            outputs = self.model(self.to_var(def_images_org, self.cuda, True))
            outputs = outputs.float()
            _, preds = torch.topk(outputs, 5)

            acc = (labels == preds.data[:, 0]).float().mean()
            top5 = torch.sum(
                (labels.unsqueeze(1).repeat(1, 5) == preds.data).float(),
                dim=1).mean()
            eval_def_metrics.update('Def-Test/Acc', float(acc), labels.size(0))
            eval_def_metrics.update('Def-Test/Top5', float(top5),
                                    labels.size(0))
            eval_def_metrics.update('Def-Test/Time',
                                    time.time() - st, labels.size(0))

            if self.step % self.args.log_step == 0 or self.step == len(
                    self.val_loader):
                self.logger.scalar_summary(eval_metrics.avg, self.step, 'TEST')
                self.logger.scalar_summary(eval_def_metrics.avg, self.step,
                                           'TEST')
                self.logger.scalar_summary(attack_metrics.avg, self.step,
                                           'ATTACK')
                self.logger.scalar_summary(defense_metrics.avg, self.step,
                                           'DEFENSE')
                self.logger.scalar_summary(dist_metrics.avg, self.step, 'DIST')

                # defense_rate = 1 - (acc drop after defense)/(acc drop after
                # attack); the truthiness test guards division by zero.
                defense_rate = eval_metrics.avg[
                    'Test/Acc'] - defense_metrics.avg['Defense/Acc']
                if eval_metrics.avg['Test/Acc'] - attack_metrics.avg[
                        'Attack/Acc']:
                    defense_rate /= eval_metrics.avg[
                        'Test/Acc'] - attack_metrics.avg['Attack/Acc']
                else:
                    defense_rate = 0
                defense_rate = 1 - defense_rate

                # Same recovery ratio, computed over top-5 accuracy.
                defense_top5 = eval_metrics.avg[
                    'Test/Top5'] - defense_metrics.avg['Defense/Top5']
                if eval_metrics.avg['Test/Top5'] - attack_metrics.avg[
                        'Attack/Top5']:
                    defense_top5 /= eval_metrics.avg[
                        'Test/Top5'] - attack_metrics.avg['Attack/Top5']
                else:
                    defense_top5 = 0
                defense_top5 = 1 - defense_top5

                self.logger.log(
                    "Defense Rate Top1: {:5.3f} | Defense Rate Top5: {:5.3f}".
                    format(defense_rate, defense_top5), 'DEFENSE')

            # Log the first sample of the batch as images.
            if self.step % self.args.img_log_step == 0:
                image_dict = {
                    'Original':
                    to_np(denormalize(images, self.args.dataset))[0],
                    'Attacked':
                    to_np(denormalize(adv_images, self.args.dataset))[0],
                    'Defensed':
                    to_np(denormalize(def_images, self.args.dataset))[0],
                    'Perturbation':
                    to_np(denormalize(images - adv_images,
                                      self.args.dataset))[0]
                }
                self.logger.image_summary(image_dict, self.step)
class PixelDeflection:
    """Pixel Deflection defense.

    Randomly swaps pixels with nearby neighbours ("deflections"), sparing
    regions a robust class activation map (RCAM) marks as salient, then
    applies a classical denoiser to smooth out the introduced noise.

    Most of the code is from https://github.com/iamaaditya/pixel-deflection
    """
    def __init__(self, model, ndeflection=100, window=10, sigma=0.04, denoiser='wavelet',
                 rcam=True, args=None, **kwargs):
        """
        Most of the code is from https://github.com/iamaaditya/pixel-deflection

        Args:
            model: classifier used for CAM weights and re-labeling.
            ndeflection: number of pixel swaps applied per image.
            window: max |offset| (exclusive) of the replacement pixel.
            sigma: denoiser strength (wavelet denoiser only).
            denoiser: 'wavelet' | 'TVM' | 'bilateral' | 'deconv' | 'NLM'.
            rcam: when True, precompute per-class CAM weights up front.
            args: experiment config; needs .model, .dataset, .workers,
                .num_classes.
        """
        self.model = model
        self.ndeflection = ndeflection
        self.window = window
        self.sigma = sigma
        self.denoiser = denoiser
        self.args = args
        if rcam: self.set_rcam()

    def generate(self, images, labels):
        """Defend a batch; return (defended images, predicted labels).

        Images (Tensor)
        Labels (Tensor)
        """
        self.original_shape = images[0].shape

        def_imgs = [self.generate_sample(image, label) for (image, label)
                    in zip(images, labels)]
        def_imgs = torch.stack(def_imgs)
        # Re-label the defended images with the model's own predictions.
        def_outputs = self.model(to_var(def_imgs, volatile=True))
        def_probs, def_labels = torch.max(def_outputs, 1)

        return def_imgs, def_labels

    def generate_sample(self, image, label):
        """Defend a single (C, H, W) image: deflect, denoise, renormalize."""
        rcam = self.get_rcam(image)
        def_image = self.pixel_deflection(image, rcam, self.ndeflection, self.window)
        # The denoisers expect an HWC array in image scale, so denormalize
        # and transpose first, then undo both afterwards.
        def_image = denormalize(def_image.unsqueeze(0), self.args.dataset).squeeze(0)
        def_image = np.transpose(def_image.cpu().numpy(), [1, 2, 0])
        def_image = self.denoise(self.denoiser, def_image, self.sigma)
        def_image = np.transpose(def_image, [2, 0, 1])
        def_image = torch.FloatTensor(def_image).cuda()
        def_image = normalize(def_image.unsqueeze(0), self.args.dataset).squeeze(0)

        return def_image

    @staticmethod
    def pixel_deflection(img, rcam, ndeflection, window):
        """Replace random pixels with nearby ones, sparing salient regions.

        A candidate position (x, y) is skipped with probability rcam[x, y],
        so highly activated regions are deflected less often.

        Args:
            img: (C, H, W) tensor; returned deflected copy, input unchanged.
            rcam: (H, W) saliency map with values in [0, 1].
            ndeflection: total number of swaps to perform.
            window: max |offset| (exclusive) for the source pixel.

        Returns:
            A new (C, H, W) tensor with ``ndeflection`` pixels deflected.
        """
        # BUGFIX: operate on a copy. The original mutated the caller's tensor
        # in place; since batch iteration yields views, that corrupted images
        # the caller reuses after generate() (e.g. for metrics/logging).
        img = img.clone()
        C, H, W = img.shape
        while ndeflection > 0:
            for c in range(C):
                x,y = np.random.randint(0,H-1), np.random.randint(0,W-1)
                if np.random.uniform() < rcam[x,y]:
                    continue

                while True: #this is to ensure that PD pixel lies inside the image
                    a,b = np.random.randint(-1*window,window), np.random.randint(-1*window,window)
                    if x+a < H and x+a > 0 and y+b < W and y+b > 0: break
                img[c,x,y] = img[c,x+a,y+b]
                ndeflection -= 1
        return img

    @staticmethod
    def denoise(denoiser_name, img, sigma):
        """Apply the denoiser named ``denoiser_name`` to an HWC image.

        Raises:
            Exception: if the name is not one of the supported denoisers.
        """
        from skimage.restoration import (denoise_tv_chambolle, denoise_bilateral, denoise_wavelet, denoise_nl_means, wiener)
        if denoiser_name == 'wavelet':
            # Input scale - [0, 1]
            return denoise_wavelet(img, sigma=sigma, mode='soft', multichannel=True, convert2ycbcr=True, method='BayesShrink')
        elif denoiser_name == 'TVM':
            return denoise_tv_chambolle(img, multichannel=True)
        elif denoiser_name == 'bilateral':
            return denoise_bilateral(img, bins=1000, multichannel=True)
        elif denoiser_name == 'deconv':
            return wiener(img)
        elif denoiser_name == 'NLM':
            return denoise_nl_means(img, multichannel=True)
        else:
            raise Exception('Incorrect denoiser mentioned. Options: wavelet, TVM, bilateral, deconv, NLM')

    def set_rcam(self):
        """Precompute per-class CAM channel weights over the training set.

        Hooks the last conv block, records the max-pooled activation of each
        training image, and accumulates them per class in ``self.weights``.
        No-op for unsupported (non-ResNet) models.
        """
        print("Creating CAM for {}".format(self.args.model))
        if 'resnet' in str.lower(type(self.model).__name__):
            last_conv = 'layer4'
        else:
            print("Model not implemented. Setting rcam=False by default.")
            return

        self.weights = EvaluationMetrics(list(range(self.args.num_classes)))
        def hook_weights(module, input, output):
            # Appends to the enclosing-scope `weights` list (reset per image).
            weights.append(F.adaptive_max_pool2d(output, (1,1)))
        handle = self.model._modules.get(last_conv).register_forward_hook(hook_weights)

        train_loader, _ = get_loader(self.args.dataset,
            batch_size=1,
            num_workers=self.args.workers
        )
        for i, (image, label) in enumerate(train_loader):
            weights = []
            _ = self.model(to_var(image, volatile=True))
            weights = weights[0].squeeze()
            label = label.squeeze()[0]
            self.weights.update(label, weights)
            if (i+1)%1000 == 0:
                print("{:5.1f}% ({}/{})".format((i+1)/len(train_loader)*100, i+1, len(train_loader)))
        handle.remove()

    def get_rcam(self, image, k=1):
        """Robust CAM: blend of the CAMs of the top-k predicted classes.

        Class CAMs are combined with exponentially decaying weights
        (1/2, 1/4, ...) by prediction rank, then normalized to [0, 1].
        Returns an (H, W) map (numpy array; a zero tensor when CAM weights
        are unavailable or the model is unsupported).
        """
        size = image.shape[-2:]
        if not hasattr(self, 'weights'):
            return torch.zeros(size)
        if 'resnet' in str.lower(type(self.model).__name__):
            last_conv = 'layer4'
        else:
            return torch.zeros(size)

        # Capture the last conv feature map of this image via a forward hook.
        features = []
        def hook_feature(module, input, output):
            features.append(output)
        handle = self.model._modules.get(last_conv).register_forward_hook(hook_feature)
        outputs = self.model(to_var(image.unsqueeze(0), volatile=True))
        outputs = to_np(outputs).squeeze()
        handle.remove()

        features = features[0]
        weights = self.weights.avg

        _, nc, h, w = features.shape
        cams = []
        for label in range(self.args.num_classes):
            # Weighted sum of channels -> (h, w) map, min-max normalized.
            cam = weights[label]@features.view(nc, h*w)
            cam = cam.view(h, w)
            cam = (cam - torch.min(cam))/(torch.max(cam) - torch.min(cam))
            cam = cam.view(1,1,*cam.shape)
            cams.append(F.upsample(cam, size, mode='bilinear'))
        rcam = 0
        # BUGFIX: np.argsort is ascending, so the original blended the
        # *lowest*-scoring classes; reverse to use the top-k predictions.
        for idx, label in enumerate(np.argsort(outputs)[::-1]):
            if idx >= k:
                break
            else:
                rcam += cams[label]/float(2**(idx+1))
        rcam = (rcam - torch.min(rcam))/(torch.max(rcam) - torch.min(rcam))
        rcam = to_np(rcam).squeeze()

        return rcam