def getDistillData(teacher_model, dataloader, work_dir, num_batch):
    """Synthesize "distilled" calibration images from a trained RetinaNet teacher.

    For each of the first ``num_batch`` batches taken from ``dataloader``, a
    randomly initialized image tensor is optimized so that the per-channel
    statistics of every Conv2d activation in the teacher backbone match the
    running statistics of the corresponding BatchNorm2d layer, and so that the
    image itself approaches zero mean / unit std per channel.  The optimized
    batches are saved to ``retinanet_distill.pth`` (the per-iteration losses to
    ``loss_log.json``) and returned.

    Args:
        teacher_model: detector whose ``backbone`` provides Conv2d/BatchNorm2d
            pairs; it is wrapped in ``MMDataParallel`` on GPU 0 for inference.
        dataloader: source of batch dicts; the ``'img'`` entry of each batch is
            replaced by the synthesized tensor, other keys are passed through.
        work_dir: unused; kept for interface compatibility with callers.
        num_batch: number of distilled batches to generate.

    Returns:
        list of batch dicts whose ``'img'`` entry holds the optimized tensors.
    """
    file_name = 'retinanet_distill.pth'
    print('generating distilled data')
    # Numerical floor used inside sqrt/var; hoisted here so it is always
    # defined before the optimization loop below uses it.
    eps = 1e-6
    hooks, hook_handles, bn_stats, distill_data = [], [], [], []
    # Hook every Conv2d output and snapshot (mean, std) of every BatchNorm2d;
    # the backbone is assumed to interleave them one-to-one (asserted below).
    for _, module in teacher_model.backbone.named_modules():
        if isinstance(module, nn.Conv2d):
            hook = output_hook()
            hooks.append(hook)
            hook_handles.append(module.register_forward_hook(hook.hook))
        if isinstance(module, nn.BatchNorm2d):
            bn_stats.append(
                (module.running_mean.detach().clone().flatten().cuda(),
                 torch.sqrt(module.running_var +
                            eps).detach().clone().flatten().cuda()))
    assert len(hooks) == len(bn_stats)
    teacher_model = MMDataParallel(teacher_model, device_ids=[0])
    teacher_model.eval()
    logs = []
    for i, gaussian_data in enumerate(dataloader):
        if i == num_batch:
            break
        logs.append({})
        # Uniform initialization in [-128, 127), scaled toward zero mean.
        size = (1, 3, 800, 1216)
        gaussian_data['img'] = [
            (torch.randint(high=255, size=size) - 128).float().cuda() / 5418.75
        ]
        gaussian_data['img'][0].requires_grad = True
        optimizer = optim.Adam([gaussian_data['img'][0]], lr=0.1)
        # Target statistics for the image itself: zero mean, unit std.
        input_mean = torch.zeros(1, 3).cuda()
        input_std = torch.ones(1, 3).cuda()
        for it in range(10):
            teacher_model.zero_grad()
            optimizer.zero_grad()
            for hook in hooks:
                hook.clear()
            # Forward pass only to populate the conv-output hooks.
            teacher_model(return_loss=False, rescale=True, **gaussian_data)
            mean_loss = 0
            std_loss = 0
            # Match each conv activation's per-channel stats to the BN stats.
            for bn_stat, hook in zip(bn_stats, hooks):
                conv_out = hook.outputs
                bn_mean, bn_std = bn_stat
                flat = conv_out.view(conv_out.size(0), conv_out.size(1), -1)
                mean_loss += own_loss(bn_mean, torch.mean(flat, dim=2))
                std_loss += own_loss(bn_std,
                                     torch.sqrt(torch.var(flat, dim=2) + eps))
            # Also pull the image toward zero mean / unit std per channel.
            img_flat = gaussian_data['img'][0].view(size[0], 3, -1)
            mean_loss += own_loss(input_mean, torch.mean(img_flat, dim=2))
            std_loss += own_loss(input_std,
                                 torch.sqrt(torch.var(img_flat, dim=2) + eps))
            total_loss = mean_loss + std_loss
            total_loss.backward()
            # .item() already yields a plain Python float; no deepcopy needed.
            logs[-1][it] = total_loss.item()
            optimizer.step()
        gaussian_data['img'][0] = gaussian_data['img'][0].detach().clone()
        distill_data.append(gaussian_data)
    for handle in hook_handles:
        handle.remove()
    torch.save(distill_data, file_name)
    # Close the log file deterministically instead of leaking the handle.
    with open('loss_log.json', 'w') as log_file:
        json.dump(logs, log_file)
    return distill_data
def attack_detector(args, model, cfg, dataset):
    """Run an iterative gradient-sign attack against a detection model.

    Each batch is perturbed for ``args.num_attack_iter`` steps by ascending the
    selected training losses (``args.loss_keys``).  Options: momentum
    accumulation of the update direction (``args.momentum``), smoothing the
    gradient through a fixed conv layer (``args.kernel_size``), and random
    resize-and-pad input diversity (``args.DIM``).  Clean and attacked images
    are evaluated, and the resulting accuracy / mAP statistics are written onto
    ``args``, which is returned.
    """
    # Banner with the attack configuration.
    print(str(datetime.now()) + ' - INFO - GPUs: ', cfg.gpus)
    print(
        str(datetime.now()) + ' - INFO - Imgs per GPU: ', cfg.data.imgs_per_gpu)
    print(
        str(datetime.now()) + ' - INFO - Workers per GPU: ',
        cfg.data.workers_per_gpu)
    print(str(datetime.now()) + ' - INFO - Momentum: ', args.momentum)
    print(str(datetime.now()) + ' - INFO - Epsilon: ', args.epsilon)
    # Separate clean model instance used for visualization / evaluation only.
    infer_model = load_model(args)
    attack_loader = build_dataloader(dataset, cfg.data.imgs_per_gpu,
                                     cfg.data.workers_per_gpu, cfg.gpus,
                                     dist=False)
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
    if args.clear_output:
        # Wipe any previous results under save_path.
        file_list = os.listdir(args.save_path)
        for f in file_list:
            if os.path.isdir(os.path.join(args.save_path, f)):
                shutil.rmtree(os.path.join(args.save_path, f))
            else:
                os.remove(os.path.join(args.save_path, f))
    class_names = infer_model.CLASSES
    num_of_classes = len(infer_model.CLASSES)
    max_batch = min(attack_loader.__len__(), args.max_attack_batches)
    pbar_outer = tqdm(total=max_batch)  # progress over batches
    pbar_inner = tqdm(total=args.num_attack_iter)  # progress over attack steps
    assert max_batch > 0
    acc_before_attack = 0
    acc_under_attack = 0
    # Accumulated per-image statistics: entries 0-2 before attack, 3-5 under.
    statistics = np.zeros(6)
    number_of_images = 0
    # Accumulated cosine similarity between successive update directions
    # (only tracked when args.visualize is set).
    dot_product = 0
    conv_kernel = None
    with_mask = hasattr(model.module,
                        'mask_head') and model.module.mask_head is not None
    # MAP_data[0] = before attack, MAP_data[1] = under attack; each holds
    # per-class [detections, annotations] lists consumed by eval_map_attack.
    MAP_data = [[[None] * num_of_classes, [None] * num_of_classes],
                [[None] * num_of_classes, [None] * num_of_classes]]
    if args.kernel_size != 0:
        conv_kernel = conv_layer(args)
    for i, data in enumerate(attack_loader):
        # Convert the pixel-space budget into normalized-image space.
        epsilon = args.epsilon / max(
            data['img_meta'].data[0][0]['img_norm_cfg']['std'])
        if i >= max_batch:
            break
        raw_imgs = copy.deepcopy(data['img'])  # clean copy for evaluation
        imgs = data['img']
        for j in range(0, len(imgs.data)):
            imgs.data[j] = imgs.data[j].cuda()
            if args.visualize:
                # The plain RPN model takes no labels; other models do.
                if args.model_name == 'rpn_r50_fpn_1x':
                    visualize_modification(args, infer_model,
                                           copy.deepcopy(imgs.data[j]), j,
                                           data['img_meta'].data[j],
                                           data['gt_bboxes'].data[j])
                else:
                    visualize_modification(args, infer_model,
                                           copy.deepcopy(imgs.data[j]), j,
                                           data['img_meta'].data[j],
                                           data['gt_bboxes'].data[j],
                                           data['gt_labels'].data[j])
            imgs.data[j] = imgs.data[j].detach()
            # Under DIM the gradient is taken on the transformed copy instead
            # of the image itself.
            if args.DIM:
                imgs.data[j].requires_grad = False
            else:
                imgs.data[j].requires_grad = True
            number_of_images += imgs.data[j].size()[0]
        pbar_inner.reset()
        # Placeholder entries; replaced by gradient tensors on iteration 0.
        last_update_direction = list(range(0, len(imgs.data)))
        for _ in range(args.num_attack_iter):
            if args.DIM:
                # Diverse-input step: with p=0.7 randomly resize each image
                # (ratio in [0.9, 1.0)) and zero-pad back to the original
                # size, transforming boxes, masks and meta to match.
                # NOTE(review): assumes 'gt_masks' is present whenever DIM is
                # enabled — confirm for box-only datasets.
                trans_imgs = copy.deepcopy(imgs)
                trans_img_meta = copy.deepcopy(data['img_meta'])
                trans_gt_bboxes = copy.deepcopy(data['gt_bboxes'])
                trans_gt_labels = copy.deepcopy(data['gt_labels'])
                trans_gt_masks = copy.deepcopy(data['gt_masks'])
                for j in range(0, len(trans_imgs.data)):
                    trans_imgs.data[j].requires_grad = True
                    original_size = trans_imgs.data[j].size()
                    img_data = []
                    for k in range(0, original_size[0]):
                        if torch.rand((1, 1))[0][0] < 0.7:
                            resize_ratio = torch.rand((1, 1))[0][0] * 0.1 + 0.9
                            pad_size_x = original_size[2] - int(
                                resize_ratio * original_size[2])
                            pad_size_y = original_size[3] - int(
                                resize_ratio * original_size[3])
                            size_meta = trans_img_meta.data[j][k]['img_shape']
                            assert size_meta[2] == 3
                            trans_img_meta.data[j][k]['img_shape'] = (int(
                                size_meta[0] * resize_ratio),
                                int(size_meta[1] * resize_ratio), size_meta[2])
                            transform = transforms.Compose([
                                transforms.Scale(
                                    (int(resize_ratio * original_size[2]),
                                     int(resize_ratio * original_size[3]))),
                                transforms.ToTensor(),
                            ])
                            norm_cfg = trans_img_meta.data[j][k]['img_norm_cfg']
                            # De-normalize -> PIL resize -> re-normalize.
                            img_temp = trans_imgs.data[j][k].cpu().float()
                            for channel in range(3):
                                img_temp[channel] = img_temp[channel] * norm_cfg['std'][channel] + \
                                    norm_cfg['mean'][channel]
                            img_temp = transform(transforms.ToPILImage()(
                                img_temp / 255.0)) * 255.0
                            for channel in range(3):
                                img_temp[channel] = (img_temp[channel] - norm_cfg['mean'][channel]) \
                                    / norm_cfg['std'][channel]
                            # Zero-pad back to the original spatial size.
                            img_temp = torch.cat(
                                (img_temp,
                                 torch.zeros(
                                     (3, img_temp.size()[1], pad_size_y))),
                                dim=2)
                            img_temp = torch.cat(
                                (img_temp,
                                 torch.zeros(
                                     (3, pad_size_x, img_temp.size()[2]))),
                                dim=1)
                            trans_gt_bboxes.data[j][k] *= resize_ratio
                            img_data.append(img_temp)
                            # Resize ground-truth masks identically, then
                            # re-binarize (interpolation blurs the edges).
                            mask_data = []
                            for l in range(
                                    np.shape(trans_gt_masks.data[j][k])[0]):
                                transform_mask = transforms.Compose([
                                    transforms.Scale(
                                        (int(resize_ratio * original_size[2]),
                                         int(resize_ratio * original_size[3]))),
                                    transforms.ToTensor(),
                                ])
                                mask_temp = transform_mask(
                                    transforms.ToPILImage()(
                                        trans_gt_masks.data[j][k][l]))
                                mask_temp = torch.where(
                                    mask_temp > 0, torch.ones_like(mask_temp),
                                    torch.zeros_like(mask_temp))
                                mask_data.append(mask_temp)
                            mask_data = torch.cat(tuple(mask_data), dim=0)
                            mask_data = torch.cat(
                                (mask_data,
                                 torch.zeros(
                                     (mask_data.size()[0], mask_data.size()[1],
                                      pad_size_y))),
                                dim=2)
                            mask_data = torch.cat(
                                (mask_data,
                                 torch.zeros((mask_data.size()[0], pad_size_x,
                                              mask_data.size()[2]))),
                                dim=1)
                            trans_gt_masks.data[j][k] = mask_data.numpy(
                            ).astype(np.uint8)
                        else:
                            img_data.append(trans_imgs.data[j][k].cpu())
                    trans_imgs.data[j] = torch.stack(tuple(img_data),
                                                     dim=0).cuda().detach()
                    trans_imgs.data[j].requires_grad = True
            else:
                # No input diversity: attack the batch tensors directly.
                trans_imgs = imgs
                trans_img_meta = data['img_meta']
                trans_gt_bboxes = data['gt_bboxes']
                trans_gt_labels = data['gt_labels']
                trans_gt_masks = None
                if with_mask:
                    trans_gt_masks = data['gt_masks']
            # Forward pass in training mode to obtain the losses to ascend.
            if args.model_name == 'rpn_r50_fpn_1x':
                result = model(trans_imgs, trans_img_meta, return_loss=True,
                               gt_bboxes=trans_gt_bboxes)
            elif with_mask:
                result = model(trans_imgs, trans_img_meta, return_loss=True,
                               gt_bboxes=trans_gt_bboxes,
                               gt_labels=trans_gt_labels,
                               gt_masks=trans_gt_masks)
            else:
                result = model(trans_imgs, trans_img_meta, return_loss=True,
                               gt_bboxes=trans_gt_bboxes,
                               gt_labels=trans_gt_labels)
            # Sum only the loss terms selected for the attack objective.
            loss = 0
            for key in args.loss_keys:
                if type(result[key]) is list:
                    for losses in result[key]:
                        loss += losses.sum()
                else:
                    loss += result[key].sum()
            loss.backward()
            for j in range(0, len(imgs.data)):
                if args.momentum == 0:
                    # Plain iterative FGSM step.
                    update_direction = trans_imgs.data[j].grad
                    if conv_kernel:
                        update_direction = conv_kernel(update_direction)
                    # Per-image L1 normalization (a no-op under sign(), kept
                    # for parity with the momentum branches below).
                    l1_per_img = torch.sum(torch.abs(update_direction),
                                           (1, 2, 3), keepdim=True)
                    l1_per_img = l1_per_img.expand(update_direction.size())
                    update_direction = update_direction / l1_per_img
                    imgs.data[j] = imgs.data[j] + epsilon / args.num_attack_iter * torch.sign(update_direction)
                else:
                    if _ == 0:
                        # First momentum step: seed the accumulator with the
                        # L1-normalized (optionally smoothed) gradient.
                        update_direction = trans_imgs.data[j].grad
                        if conv_kernel:
                            update_direction = conv_kernel(update_direction)
                        l1_per_img = torch.sum(torch.abs(update_direction),
                                               (1, 2, 3), keepdim=True)
                        l1_per_img = l1_per_img.expand(update_direction.size())
                        update_direction = update_direction / l1_per_img
                        imgs.data[j] = imgs.data[j] + epsilon / args.num_attack_iter * torch.sign(update_direction)
                    else:
                        # Later momentum steps: blend the previous direction in.
                        update_direction = trans_imgs.data[j].grad
                        if conv_kernel:
                            update_direction = conv_kernel(update_direction)
                        l1_per_img = torch.sum(torch.abs(update_direction),
                                               (1, 2, 3), keepdim=True)
                        l1_per_img = l1_per_img.expand(update_direction.size())
                        update_direction = update_direction / l1_per_img
                        update_direction += args.momentum * last_update_direction[j]
                        imgs.data[j] = imgs.data[j] + epsilon / args.num_attack_iter * torch.sign(update_direction)
                imgs.data[j] = imgs.data[j].detach()
                if args.visualize:
                    if args.model_name == 'rpn_r50_fpn_1x':
                        visualize_modification(args, infer_model,
                                               copy.deepcopy(imgs.data[j]), j,
                                               data['img_meta'].data[j],
                                               data['gt_bboxes'].data[j])
                    else:
                        visualize_modification(args, infer_model,
                                               copy.deepcopy(imgs.data[j]), j,
                                               data['img_meta'].data[j],
                                               data['gt_bboxes'].data[j],
                                               data['gt_labels'].data[j])
                imgs.data[j].requires_grad = True
                if args.visualize and _ > 0:
                    # Cosine similarity between this step's direction and the
                    # previous one.
                    dot_product += torch.sum(
                        update_direction.view(-1) /
                        torch.norm(update_direction.view(-1)) *
                        last_update_direction[j].view(-1) /
                        torch.norm(last_update_direction[j].view(-1)))
                last_update_direction[j] = update_direction
            model.zero_grad()
            pbar_inner.update(1)
        # Evaluate clean vs. attacked images, one worker thread per GPU chunk.
        for j in range(0, cfg.gpus):
            if args.model_name == 'rpn_r50_fpn_1x':
                t = ThreadingWithResult(
                    visualize_all_images_plus_acc,
                    args=(args, infer_model, imgs.data[j], raw_imgs.data[j],
                          data['img_meta'].data[j], data['gt_bboxes'].data[j],
                          MAP_data))
            else:
                t = ThreadingWithResult(
                    visualize_all_images_plus_acc,
                    args=(args, infer_model, imgs.data[j], raw_imgs.data[j],
                          data['img_meta'].data[j], data['gt_bboxes'].data[j],
                          MAP_data, data['gt_labels'].data[j]))
            t.start()
            t.join()
            statistics_result = t.get_result()
            # A negative first entry signals that the worker failed.
            if statistics_result[0][0] >= 0:
                statistics += statistics_result[0]
                MAP_data = statistics_result[1]
            else:
                print("Error! Results were not fetched!")
        pbar_outer.update(1)
    pbar_outer.close()
    pbar_inner.close()
    if args.visualize and args.num_attack_iter > 1:
        # Average the accumulated cosine similarity over all comparisons made.
        dot_product /= (args.num_attack_iter - 1) * number_of_images / args.imgs_per_gpu
        print("average normalized dot product = ", dot_product)
    acc_before_attack /= max_batch
    acc_under_attack /= max_batch
    statistics /= number_of_images  # accumulated sums -> per-image averages
    if args.neglect_raw_stat and args.experiment_index > args.resume_experiment:
        pass  # skip recomputing clean-image statistics on resumed experiments
    else:
        args.class_accuracy_before_attack = 100 * statistics[0]
        args.IoU_accuracy_before_attack = 100 * statistics[1]
        args.IoU_accuracy_before_attack2 = 100 * statistics[2]
        # NOTE(review): MAP_data[0] is initialized to a list above and only
        # reassigned by the evaluation threads — confirm it can ever be None.
        if MAP_data[0] is None:
            args.MAP_before_attack = 0
        else:
            args.MAP_before_attack = eval_map_attack(MAP_data[0][0],
                                                     MAP_data[0][1],
                                                     len(class_names),
                                                     scale_ranges=None,
                                                     iou_thr=0.5,
                                                     dataset=class_names,
                                                     print_summary=True)[0]
    args.class_accuracy_under_attack = 100 * statistics[3]
    args.IoU_accuracy_under_attack = 100 * statistics[4]
    args.IoU_accuracy_under_attack2 = 100 * statistics[5]
    if MAP_data[1] is None:
        args.MAP_under_attack = 0
    else:
        args.MAP_under_attack = eval_map_attack(MAP_data[1][0],
                                                MAP_data[1][1],
                                                len(class_names),
                                                scale_ranges=None,
                                                iou_thr=0.5,
                                                dataset=class_names,
                                                print_summary=True)[0]
    args.class_accuracy_decrease = args.class_accuracy_before_attack - args.class_accuracy_under_attack
    args.IoU_accuracy_decrease = args.IoU_accuracy_before_attack - args.IoU_accuracy_under_attack
    args.IoU_accuracy_decrease2 = args.IoU_accuracy_before_attack2 - args.IoU_accuracy_under_attack2
    args.MAP_decrease = 100 * (args.MAP_before_attack - args.MAP_under_attack)
    print("Class & IoU accuracy before attack = %g %g" %
          (args.class_accuracy_before_attack, args.IoU_accuracy_before_attack))
    print("Class & IoU accuracy under attack = %g %g" %
          (args.class_accuracy_under_attack, args.IoU_accuracy_under_attack))
    print("Class & IoU accuracy decrease = %g %g" %
          (args.class_accuracy_decrease, args.IoU_accuracy_decrease))
    print("MAP before attack = %g" % args.MAP_before_attack)
    print("MAP under attack = %g" % args.MAP_under_attack)
    print("MAP decrease = %g" % args.MAP_decrease)
    # torch.cuda.empty_cache()
    return args