Пример #1
0
 def forward(self, x):
     """Run the four backbone stages, pool, and classify.

     The output of ``stage2`` is detached, so no gradient from this output
     reaches ``stage1`` or ``stage2``; only ``stage3``, ``stage4`` and
     ``classifier`` are trained through it.

     Returns:
         Classifier output reshaped to ``(-1, 20)``.
     """
     feats = self.stage2(self.stage1(x)).detach()
     feats = self.stage4(self.stage3(feats))
     pooled = torchutils.gap2d(feats, keepdims=True)
     return self.classifier(pooled).view(-1, 20)
Пример #2
0
 def forward(self, x):
     """Classify ``x``, producing a single output score per sample.

     Inputs whose dim 1 has size 1 are repeated three times along that
     dimension first. The output of ``stage2`` is detached, so no gradient
     from this output reaches ``stage1`` or ``stage2``.

     Returns:
         Classifier output reshaped to ``(-1, 1)``.
     """
     is_single_channel = x.size(1) == 1
     if is_single_channel:
         x = x.repeat(1, 3, 1, 1)
     feats = self.stage2(self.stage1(x)).detach()
     feats = self.stage4(self.stage3(feats))
     logits = self.classifier(torchutils.gap2d(feats, keepdims=True))
     return logits.view(-1, 1)
Пример #3
0
    def forward(self, x):
        """Run the backbone and classifier, returning ``(-1, n_classes)`` logits.

        When ``self.pretrained`` is true, the ``stage2`` output is detached,
        so no gradient from this output reaches ``stage1``/``stage2``. When
        training from scratch (e.g. for L8Biome), gradients flow through all
        stages.
        """
        out = self.stage2(self.stage1(x))
        if self.pretrained:
            out = out.detach()

        out = self.stage4(self.stage3(out))

        out = self.classifier(torchutils.gap2d(out, keepdims=True))
        return out.view(-1, self.n_classes)
def _adv_iter_schedule(size_idx, adv_iter):
    """Return ``(n_iterations, cam_weight)`` for one input scale.

    Only scale index 1 runs a reduced number of adversarial iterations. Its
    per-iteration CAM is up-weighted so that ``n_iterations * cam_weight``
    stays close to ``adv_iter``. Every other scale runs the full ``adv_iter``
    iterations with weight 1.
    """
    if size_idx != 1:
        return adv_iter, 1
    if adv_iter > 10:
        return adv_iter // 2, 2
    if adv_iter < 6:
        return adv_iter, 1
    # 6 <= adv_iter <= 10: cap at 5 iterations and compensate with
    # adv_iter / 5 (5 / 5 would always be 1.0 and compensate nothing).
    return 5, float(adv_iter) / 5


def _fuse_scale_cams(outputs_cam, size, strided_size, strided_up_size):
    """Sum per-scale CAMs at two resolutions and max-normalise each class map.

    Args:
        outputs_cam: one ``(n_classes, h, w)`` CAM tensor per input scale.
        size: original image size; the high-resolution result is cropped to
            ``size[0] x size[1]``.
        strided_size: output size of the low-resolution (strided) CAM.
        strided_up_size: interpolation size of the high-resolution CAM
            before cropping.

    Returns:
        ``(strided_cam, highres_cam)``, each divided by its per-class spatial
        maximum plus 1e-5.
    """
    strided_cam = torch.sum(
        torch.stack([
            F.interpolate(torch.unsqueeze(o, 0),
                          strided_size,
                          mode='bilinear',
                          align_corners=False)[0] for o in outputs_cam
        ]), 0)
    highres_cam = torch.sum(
        torch.stack([
            F.interpolate(torch.unsqueeze(o, 1),
                          strided_up_size,
                          mode='bilinear',
                          align_corners=False) for o in outputs_cam
        ], 0), 0)[:, 0, :size[0], :size[1]]
    strided_cam /= F.adaptive_max_pool2d(strided_cam, (1, 1)) + 1e-5
    highres_cam /= F.adaptive_max_pool2d(highres_cam, (1, 1)) + 1e-5
    return strided_cam, highres_cam


def _work(process_id, model, dataset, args):
    """Compute adversarially refined multi-scale CAMs for one data shard.

    Processing per image in ``dataset[process_id]``:
      * For each input scale and each class whose label entry is non-zero,
        perturb the image for several steps with ``adv_climb``.
      * Accumulate the map returned by ``gcam.generate`` at every step.
      * Fuse the per-scale maps and save them to
        ``<args.cam_out_dir>/<name>.npy``. The saved dict has keys
        ``"keys"`` (class indices), ``"cam"`` (strided CAM tensor) and
        ``"high_res"`` (high-resolution CAM array).

    Images whose output file already exists are skipped. Images with no
    positive label are skipped with a message instead of producing a stale
    or undefined CAM.

    Args:
        process_id: index of the data shard and of the CUDA device to use.
        model: network wrapped by ``GradCAM``; moved to the selected GPU.
        dataset: indexable collection of per-process datasets.
        args: namespace providing ``num_workers``, ``target_layer``,
            ``cam_out_dir``, ``adv_iter``, ``score_th``, ``AD_coeff`` and
            ``AD_stepsize``.

    Raises:
        ValueError: if ``args.adv_iter`` is smaller than 1 (no CAM would be
            produced for any scale).
    """
    if args.adv_iter < 1:
        raise ValueError(
            "args.adv_iter must be >= 1, got {}".format(args.adv_iter))

    databin = dataset[process_id]
    n_gpus = torch.cuda.device_count()
    data_loader = DataLoader(databin,
                             shuffle=False,
                             num_workers=args.num_workers // n_gpus,
                             pin_memory=True)
    print("dcpu", args.num_workers // n_gpus)
    with cuda.device(process_id):
        model.cuda()
        gcam = GradCAM(model=model, candidate_layers=[args.target_layer])
        for pack in data_loader:
            img_name = pack['name'][0]
            out_path = os.path.join(args.cam_out_dir, img_name + '.npy')
            if os.path.exists(out_path):
                continue

            # Indices of the non-zero entries of this image's label vector.
            valid_cat = torch.nonzero(pack['label'][0])[:, 0]
            n_classes = len(valid_cat)
            if n_classes == 0:
                print("skip", img_name, "(no positive labels)")
                continue

            size = pack['size']
            strided_size = imutils.get_strided_size(size, 4)
            strided_up_size = imutils.get_strided_up_size(size, 16)
            outputs_cam = []

            for size_idx in (1, 0, 2, 3):
                total_adv_iter, mul_for_scale = _adv_iter_schedule(
                    size_idx, args.adv_iter)
                orig_img = pack['img'][size_idx].clone()
                # Allocated on the first forward pass, once the CAM size is known.
                cam_all_classes = None

                for c_idx, c in enumerate(valid_cat):
                    # Every class starts its climb from the unperturbed image.
                    pack['img'][size_idx] = orig_img
                    # Presumably dim 0 now holds [original, flipped] copies
                    # (cf. regions[1].flip(-1) below) -- confirm with dataset.
                    img_single = pack['img'][size_idx].detach()[0]

                    for it in range(total_adv_iter):
                        img_single.requires_grad = True

                        outputs = gcam.forward(
                            img_single.cuda(non_blocking=True))

                        if cam_all_classes is None:
                            cam_all_classes = torch.zeros([
                                n_classes, outputs.shape[2], outputs.shape[3]
                            ])

                        gcam.backward(ids=c)

                        regions = gcam.generate(target_layer=args.target_layer)
                        # Add the second copy's map, flipped back on the last axis.
                        regions = regions[0] + regions[1].flip(-1)

                        if it == 0:
                            init_cam = regions.detach()

                        cam_all_classes[c_idx] += (
                            regions[0].data.cpu() * mul_for_scale)

                        logit = F.relu(outputs)
                        logit = torchutils.gap2d(logit,
                                                 keepdims=True)[:, :, 0, 0]
                        # Equals sum(other-class logits) - logit[c], summed
                        # over the batch dimension.
                        logit_loss = -2 * (logit[:, c]).sum() + torch.sum(logit)

                        # Penalise change of the CAM, relative to the first
                        # iteration, inside the mask from add_discriminative.
                        expanded_mask = torch.zeros(regions.shape)
                        expanded_mask = add_discriminative(
                            expanded_mask, regions, score_th=args.score_th)

                        L_AD = torch.sum((torch.abs(regions - init_cam)) *
                                         expanded_mask.cuda())

                        loss = -logit_loss - L_AD * args.AD_coeff

                        model.zero_grad()
                        img_single.grad.zero_()
                        loss.backward()

                        data_grad = img_single.grad.data

                        perturbed_data = adv_climb(img_single,
                                                   args.AD_stepsize, data_grad)
                        img_single = perturbed_data.detach()

                outputs_cam.append(cam_all_classes)

            strided_cam, highres_cam = _fuse_scale_cams(
                outputs_cam, size, strided_size, strided_up_size)

            np.save(
                out_path, {
                    "keys": valid_cat,
                    "cam": strided_cam.cpu(),
                    "high_res": highres_cam.cpu().numpy()
                })