Example #1
import json
import os

import torch
import torch.nn as nn
import torch.optim as optim
from mmcv.parallel import MMDataParallel

# output_hook and own_loss are project-local helpers; a sketch of plausible
# implementations is given after this listing.


def getDistillData(teacher_model, dataloader, work_dir, num_batch):
    # Save the distilled batches under the experiment's work directory.
    file_name = os.path.join(work_dir, 'retinanet_distill.pth')

    print('generating distilled data')

    eps = 1e-6
    hooks, hook_handles, bn_stats, distill_data = [], [], [], []

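    # Hook every Conv2d output and record each BatchNorm2d's running
    # (mean, std); the distilled batch will be optimized to reproduce them.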
    for n, m in teacher_model.backbone.named_modules():
        if isinstance(m, nn.Conv2d):
            hook = output_hook()
            hooks.append(hook)
            hook_handles.append(m.register_forward_hook(hook.hook))
        if isinstance(m, nn.BatchNorm2d):
            bn_stats.append(
                (m.running_mean.detach().clone().flatten().cuda(),
                 torch.sqrt(m.running_var +
                            eps).detach().clone().flatten().cuda()))

    assert len(hooks) == len(bn_stats)
    teacher_model = MMDataParallel(teacher_model, device_ids=[0])
    teacher_model.eval()

    logs = []
    for i, gaussian_data in enumerate(dataloader):
        if i == num_batch:
            break
        logs.append({})
        # Uniform random initialization, shifted and scaled toward zero mean
        size = (1, 3, 800, 1216)
        gaussian_data['img'] = [
            (torch.randint(high=255, size=size) - 128).float().cuda() / 5418.75
        ]

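        # The synthetic batch itself is the optimization variable.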
        gaussian_data['img'][0].requires_grad = True
        optimizer = optim.Adam([gaussian_data['img'][0]], lr=0.1)
        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=1e-4, verbose=False, patience=100)
        input_mean = torch.zeros(1, 3).cuda()
        input_std = torch.ones(1, 3).cuda()

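        # Inner loop: nudge the synthetic batch so each Conv output's
        # per-channel (mean, std) matches the stored BN statistics, and the
        # input's own statistics match N(0, 1).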
        for it in range(10):
            teacher_model.zero_grad()
            optimizer.zero_grad()
            for hook in hooks:
                hook.clear()
            output = teacher_model(return_loss=False,
                                   rescale=True,
                                   **gaussian_data)
            mean_loss = 0
            std_loss = 0
            for cnt, (bn_stat, hook) in enumerate(zip(bn_stats, hooks)):
                tmp_output = hook.outputs
                bn_mean, bn_std = bn_stat[0], bn_stat[1]
                tmp_mean = torch.mean(tmp_output.view(tmp_output.size(0),
                                                      tmp_output.size(1), -1),
                                      dim=2)
                tmp_std = torch.sqrt(
                    torch.var(tmp_output.view(tmp_output.size(0),
                                              tmp_output.size(1), -1),
                              dim=2) + eps)
                mean_loss += own_loss(bn_mean, tmp_mean)
                std_loss += own_loss(bn_std, tmp_std)
                # print(cnt, mean_loss.item(), std_loss.item())
            tmp_mean = torch.mean(gaussian_data['img'][0].view(size[0], 3, -1),
                                  dim=2)
            tmp_std = torch.sqrt(
                torch.var(gaussian_data['img'][0].view(size[0], 3, -1), dim=2)
                + eps)
            mean_loss += own_loss(input_mean, tmp_mean)
            std_loss += own_loss(input_std, tmp_std)
            total_loss = mean_loss + std_loss
            total_loss.backward()
            # print(it, mean_loss.item(), std_loss.item())
            logs[-1][it] = total_loss.item()
            optimizer.step()
            # scheduler.step(total_loss.item())
            # if total_loss <= (len(hooks) + 1) * 10:
            # 	break
        gaussian_data['img'][0] = gaussian_data['img'][0].detach().clone()
        distill_data.append(gaussian_data)

    for handle in hook_handles:
        handle.remove()

    torch.save(distill_data, file_name)
    with open('loss_log.json', 'w') as f:
        json.dump(logs, f)

    return distill_data
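
The listing above calls two helpers that are not shown: output_hook and own_loss. A minimal sketch of plausible implementations, assuming the hook simply caches the most recent forward output and the loss is a size-normalized squared L2 distance (both are assumptions, not the original definitions):

import torch


class output_hook(object):

    # Forward hook object that caches the most recent module output.
    def __init__(self):
        self.outputs = None

    def hook(self, module, input, output):
        self.outputs = output

    def clear(self):
        self.outputs = None


def own_loss(target, value):
    # Size-normalized squared L2 distance between a stored BN statistic
    # and the matching statistic of the current activations.
    return (target - value).norm() ** 2 / value.numel()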
Example #2
import copy
import os
import shutil
from datetime import datetime

import numpy as np
import torch
import torchvision.transforms as transforms
from mmcv.parallel import MMDataParallel
from mmdet.datasets import build_dataloader
from tqdm import tqdm

# Project-local helpers (load_model, conv_layer, visualize_modification,
# ThreadingWithResult, visualize_all_images_plus_acc, eval_map_attack) are
# assumed to be importable from the surrounding codebase; sketches of two of
# them are given after this listing.


def attack_detector(args, model, cfg, dataset):
    print(str(datetime.now()) + ' - INFO - GPUs: ', cfg.gpus)
    print(
        str(datetime.now()) + ' - INFO - Imgs per GPU: ',
        cfg.data.imgs_per_gpu)
    print(
        str(datetime.now()) + ' - INFO - Workers per GPU: ',
        cfg.data.workers_per_gpu)
    print(str(datetime.now()) + ' - INFO - Momentum: ', args.momentum)
    print(str(datetime.now()) + ' - INFO - Epsilon: ', args.epsilon)
    infer_model = load_model(args)
    attack_loader = build_dataloader(dataset,
                                     cfg.data.imgs_per_gpu,
                                     cfg.data.workers_per_gpu,
                                     cfg.gpus,
                                     dist=False)
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
    if args.clear_output:
        file_list = os.listdir(args.save_path)
        for f in file_list:
            if os.path.isdir(os.path.join(args.save_path, f)):
                shutil.rmtree(os.path.join(args.save_path, f))
            else:
                os.remove(os.path.join(args.save_path, f))
    class_names = infer_model.CLASSES
    num_of_classes = len(infer_model.CLASSES)
    max_batch = min(len(attack_loader), args.max_attack_batches)
    pbar_outer = tqdm(total=max_batch)
    pbar_inner = tqdm(total=args.num_attack_iter)
    assert max_batch > 0
    statistics = np.zeros(6)
    number_of_images = 0
    dot_product = 0
    conv_kernel = None
    with_mask = hasattr(model.module,
                        'mask_head') and model.module.mask_head is not None
    MAP_data = [[[None] * num_of_classes, [None] * num_of_classes],
                [[None] * num_of_classes, [None] * num_of_classes]]
    if args.kernel_size != 0:
        conv_kernel = conv_layer(args)
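    # Main attack loop: perturb up to max_batch batches, then score them.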
    for i, data in enumerate(attack_loader):
        epsilon = args.epsilon / max(
            data['img_meta'].data[0][0]['img_norm_cfg']['std'])
        if i >= max_batch:
            break
        raw_imgs = copy.deepcopy(data['img'])
        imgs = data['img']
        for j in range(0, len(imgs.data)):
            imgs.data[j] = imgs.data[j].cuda()
            if args.visualize:
                if args.model_name == 'rpn_r50_fpn_1x':
                    visualize_modification(args, infer_model,
                                           copy.deepcopy(imgs.data[j]), j,
                                           data['img_meta'].data[j],
                                           data['gt_bboxes'].data[j])
                else:
                    visualize_modification(args, infer_model,
                                           copy.deepcopy(imgs.data[j]), j,
                                           data['img_meta'].data[j],
                                           data['gt_bboxes'].data[j],
                                           data['gt_labels'].data[j])
            imgs.data[j] = imgs.data[j].detach()
            if args.DIM:
                imgs.data[j].requires_grad = False
            else:
                imgs.data[j].requires_grad = True
            number_of_images += imgs.data[j].size()[0]
        pbar_inner.reset()
        last_update_direction = list(range(0, len(imgs.data)))
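        # Each attack step: optionally apply DIM input diversity (random
        # resize + zero-pad), run a forward pass with ground truth to get
        # detection losses, and ascend their gradient.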
        for step in range(args.num_attack_iter):
            if args.DIM:
                trans_imgs = copy.deepcopy(imgs)
                trans_img_meta = copy.deepcopy(data['img_meta'])
                trans_gt_bboxes = copy.deepcopy(data['gt_bboxes'])
                trans_gt_labels = copy.deepcopy(data['gt_labels'])
                trans_gt_masks = copy.deepcopy(data['gt_masks'])
                for j in range(0, len(trans_imgs.data)):
                    trans_imgs.data[j].requires_grad = True
                    original_size = trans_imgs.data[j].size()
                    img_data = []
                    for k in range(0, original_size[0]):
                        if torch.rand(1).item() < 0.7:
                            resize_ratio = torch.rand(1).item() * 0.1 + 0.9
                            pad_size_x = original_size[2] - int(
                                resize_ratio * original_size[2])
                            pad_size_y = original_size[3] - int(
                                resize_ratio * original_size[3])
                            size_meta = trans_img_meta.data[j][k]['img_shape']
                            assert size_meta[2] == 3
                            trans_img_meta.data[j][k]['img_shape'] = (
                                int(size_meta[0] * resize_ratio),
                                int(size_meta[1] * resize_ratio),
                                size_meta[2])
                            transform = transforms.Compose([
                                transforms.Resize(
                                    (int(resize_ratio * original_size[2]),
                                     int(resize_ratio * original_size[3]))),
                                transforms.ToTensor(),
                            ])
                            norm_cfg = trans_img_meta.data[j][k][
                                'img_norm_cfg']
                            img_temp = trans_imgs.data[j][k].cpu().float()
                            for channel in range(3):
                                img_temp[channel] = img_temp[channel] * norm_cfg['std'][channel] + \
                                                    norm_cfg['mean'][channel]
                            img_temp = transform(transforms.ToPILImage()(
                                img_temp / 255.0)) * 255.0
                            for channel in range(3):
                                img_temp[channel] = (img_temp[channel] - norm_cfg['mean'][channel]) \
                                                    / norm_cfg['std'][channel]
                            img_temp = torch.cat(
                                (img_temp,
                                 torch.zeros(
                                     (3, img_temp.size()[1], pad_size_y))),
                                dim=2)
                            img_temp = torch.cat(
                                (img_temp,
                                 torch.zeros(
                                     (3, pad_size_x, img_temp.size()[2]))),
                                dim=1)
                            trans_gt_bboxes.data[j][k] *= resize_ratio
                            img_data.append(img_temp)
                            mask_data = []
                            for l in range(
                                    np.shape(trans_gt_masks.data[j][k])[0]):
                                transform_mask = transforms.Compose([
                                    transforms.Resize(
                                        (int(resize_ratio * original_size[2]),
                                         int(resize_ratio * original_size[3]))),
                                    transforms.ToTensor(),
                                ])
                                mask_temp = transform_mask(
                                    transforms.ToPILImage()(
                                        trans_gt_masks.data[j][k][l]))
                                mask_temp = torch.where(
                                    mask_temp > 0, torch.ones_like(mask_temp),
                                    torch.zeros_like(mask_temp))
                                mask_data.append(mask_temp)
                            mask_data = torch.cat(tuple(mask_data), dim=0)
                            mask_data = torch.cat(
                                (mask_data,
                                 torch.zeros(
                                     (mask_data.size()[0], mask_data.size()[1],
                                      pad_size_y))),
                                dim=2)
                            mask_data = torch.cat(
                                (mask_data,
                                 torch.zeros((mask_data.size()[0], pad_size_x,
                                              mask_data.size()[2]))),
                                dim=1)
                            trans_gt_masks.data[j][k] = \
                                mask_data.numpy().astype(np.uint8)
                        else:
                            img_data.append(trans_imgs.data[j][k].cpu())
                    trans_imgs.data[j] = torch.stack(tuple(img_data),
                                                     dim=0).cuda().detach()
                    trans_imgs.data[j].requires_grad = True
            else:
                trans_imgs = imgs
                trans_img_meta = data['img_meta']
                trans_gt_bboxes = data['gt_bboxes']
                trans_gt_labels = data['gt_labels']
                trans_gt_masks = None
                if with_mask:
                    trans_gt_masks = data['gt_masks']
            if args.model_name == 'rpn_r50_fpn_1x':
                result = model(trans_imgs,
                               trans_img_meta,
                               return_loss=True,
                               gt_bboxes=trans_gt_bboxes)
            elif with_mask:
                result = model(trans_imgs,
                               trans_img_meta,
                               return_loss=True,
                               gt_bboxes=trans_gt_bboxes,
                               gt_labels=trans_gt_labels,
                               gt_masks=trans_gt_masks)
            else:
                result = model(trans_imgs,
                               trans_img_meta,
                               return_loss=True,
                               gt_bboxes=trans_gt_bboxes,
                               gt_labels=trans_gt_labels)
            loss = 0
            for key in args.loss_keys:
                if type(result[key]) is list:
                    for losses in result[key]:
                        loss += losses.sum()
                else:
                    loss += result[key].sum()
            loss.backward()
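            # MI-FGSM-style update per GPU chunk: L1-normalize the
            # (optionally kernel-smoothed) gradient, fold in momentum after
            # the first step, then take a signed step of size
            # epsilon / num_attack_iter.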
            for j in range(0, len(imgs.data)):
                update_direction = trans_imgs.data[j].grad
                if conv_kernel is not None:
                    update_direction = conv_kernel(update_direction)
                l1_per_img = torch.sum(torch.abs(update_direction),
                                       (1, 2, 3),
                                       keepdim=True)
                l1_per_img = l1_per_img.expand(update_direction.size())
                update_direction = update_direction / l1_per_img
                if args.momentum != 0 and step > 0:
                    update_direction += args.momentum * \
                        last_update_direction[j]
                imgs.data[j] = imgs.data[j] + epsilon / \
                    args.num_attack_iter * torch.sign(update_direction)
                imgs.data[j] = imgs.data[j].detach()
                if args.visualize:
                    if args.model_name == 'rpn_r50_fpn_1x':
                        visualize_modification(args, infer_model,
                                               copy.deepcopy(imgs.data[j]), j,
                                               data['img_meta'].data[j],
                                               data['gt_bboxes'].data[j])
                    else:
                        visualize_modification(args, infer_model,
                                               copy.deepcopy(imgs.data[j]), j,
                                               data['img_meta'].data[j],
                                               data['gt_bboxes'].data[j],
                                               data['gt_labels'].data[j])
                imgs.data[j].requires_grad = True
                if args.visualize and step > 0:
                    dot_product += torch.sum(
                        update_direction.view(-1) /
                        torch.norm(update_direction.view(-1)) *
                        last_update_direction[j].view(-1) /
                        torch.norm(last_update_direction[j].view(-1)))
                last_update_direction[j] = update_direction
            model.zero_grad()
            pbar_inner.update(1)
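        # Score clean vs. adversarial images and accumulate per-class
        # detections for mAP evaluation on a worker thread.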
        for j in range(0, cfg.gpus):
            if args.model_name == 'rpn_r50_fpn_1x':
                t = ThreadingWithResult(
                    visualize_all_images_plus_acc,
                    args=(args, infer_model, imgs.data[j], raw_imgs.data[j],
                          data['img_meta'].data[j], data['gt_bboxes'].data[j],
                          MAP_data))
            else:
                t = ThreadingWithResult(
                    visualize_all_images_plus_acc,
                    args=(args, infer_model, imgs.data[j], raw_imgs.data[j],
                          data['img_meta'].data[j], data['gt_bboxes'].data[j],
                          MAP_data, data['gt_labels'].data[j]))
            t.start()
            t.join()
            statistics_result = t.get_result()
            if statistics_result[0][0] >= 0:
                statistics += statistics_result[0]
                MAP_data = statistics_result[1]
            else:
                print("Error! Results were not fetched!")
        pbar_outer.update(1)
    pbar_outer.close()
    pbar_inner.close()
    if args.visualize and args.num_attack_iter > 1:
        dot_product /= (args.num_attack_iter -
                        1) * number_of_images / args.imgs_per_gpu
        print("average normalized dot product = ", dot_product)
    statistics /= number_of_images

    if args.neglect_raw_stat and args.experiment_index > args.resume_experiment:
        pass
    else:
        args.class_accuracy_before_attack = 100 * statistics[0]
        args.IoU_accuracy_before_attack = 100 * statistics[1]
        args.IoU_accuracy_before_attack2 = 100 * statistics[2]
        if MAP_data[0] is None:
            args.MAP_before_attack = 0
        else:
            args.MAP_before_attack = eval_map_attack(MAP_data[0][0],
                                                     MAP_data[0][1],
                                                     len(class_names),
                                                     scale_ranges=None,
                                                     iou_thr=0.5,
                                                     dataset=class_names,
                                                     print_summary=True)[0]
    args.class_accuracy_under_attack = 100 * statistics[3]
    args.IoU_accuracy_under_attack = 100 * statistics[4]
    args.IoU_accuracy_under_attack2 = 100 * statistics[5]
    if MAP_data[1] is None:
        args.MAP_under_attack = 0
    else:
        args.MAP_under_attack = eval_map_attack(MAP_data[1][0],
                                                MAP_data[1][1],
                                                len(class_names),
                                                scale_ranges=None,
                                                iou_thr=0.5,
                                                dataset=class_names,
                                                print_summary=True)[0]
    args.class_accuracy_decrease = args.class_accuracy_before_attack - args.class_accuracy_under_attack
    args.IoU_accuracy_decrease = args.IoU_accuracy_before_attack - args.IoU_accuracy_under_attack
    args.IoU_accuracy_decrease2 = args.IoU_accuracy_before_attack2 - args.IoU_accuracy_under_attack2
    args.MAP_decrease = 100 * (args.MAP_before_attack - args.MAP_under_attack)
    print("Class & IoU accuracy before attack = %g %g" %
          (args.class_accuracy_before_attack, args.IoU_accuracy_before_attack))
    print("Class & IoU accuracy under attack = %g %g" %
          (args.class_accuracy_under_attack, args.IoU_accuracy_under_attack))
    print("Class & IoU accuracy decrease = %g %g" %
          (args.class_accuracy_decrease, args.IoU_accuracy_decrease))
    print("MAP before attack = %g" % args.MAP_before_attack)
    print("MAP under attack = %g" % args.MAP_under_attack)
    print("MAP decrease = %g" % args.MAP_decrease)
    # torch.cuda.empty_cache()
    return args
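
attack_detector depends on several project-local helpers that are not shown. Two are small enough to sketch under explicit assumptions. conv_layer(args) is applied to gradients whenever args.kernel_size != 0; in translation-invariant attacks this is typically a fixed depthwise Gaussian blur, so a plausible (assumed, not the original) implementation is:

import numpy as np
import scipy.stats as st
import torch
import torch.nn as nn


def conv_layer(args):
    # Fixed, non-trainable depthwise Gaussian convolution that smooths
    # per-channel gradients; kernel size comes from args.kernel_size.
    k = args.kernel_size
    kern1d = st.norm.pdf(np.linspace(-3, 3, k))
    kernel = np.outer(kern1d, kern1d)
    kernel /= kernel.sum()
    weight = torch.from_numpy(kernel).float().view(1, 1, k, k)
    conv = nn.Conv2d(3, 3, kernel_size=k, padding=k // 2,
                     groups=3, bias=False).cuda()
    conv.weight.data = weight.expand(3, 1, k, k).clone().cuda()
    conv.weight.requires_grad = False
    return conv

ThreadingWithResult is used as a threading.Thread subclass whose get_result() returns whatever the target function returned; a minimal sketch under that assumption:

import threading


class ThreadingWithResult(threading.Thread):

    # Thread subclass that stores the target's return value so it can be
    # fetched with get_result() after join().
    def __init__(self, target, args=()):
        super().__init__()
        self._target_fn = target
        self._args = args
        self._result = None

    def run(self):
        self._result = self._target_fn(*self._args)

    def get_result(self):
        return self._result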