surrogate_model.cuda() surrogate_model.eval() attacker = SwitchNeg(args.dataset, args.batch_size, args.targeted, args.target_type, args.epsilon, args.norm, 0.0, 1.0, args.max_queries) for arch in archs: if args.attack_defense: save_result_path = args.exp_dir + "/{}_{}_result.json".format( arch, args.defense_model) else: save_result_path = args.exp_dir + "/{}_result.json".format(arch) if os.path.exists(save_result_path): continue log.info("Begin attack {} on {}, result will be saved to {}".format( arch, args.dataset, save_result_path)) if args.attack_defense: model = DefensiveModel(args.dataset, arch, no_grad=True, defense_model=args.defense_model) else: model = StandardModel(args.dataset, arch, no_grad=True) model.cuda() model.eval() attacker.attack_all_images(args, arch, model, surrogate_model, save_result_path) model.cpu() log.info("Save result of attacking {} done".format(arch))
def main():
    """Command-line entry point for the MetaSimulatorSquareAttack experiments.

    Parses the CLI, overlays the non-None CLI values on top of the JSON
    defaults selected by ``[dataset][norm]``, sets up the experiment directory
    and log file, resolves the list of target architectures, and attacks each
    of them in turn.  Architectures whose result file already exists are
    skipped, so an interrupted run can simply be restarted.
    """
    parser = argparse.ArgumentParser(
        description='Square Attack Hyperparameters.')
    parser.add_argument('--norm', type=str, required=True,
                        choices=['l2', 'linf'])
    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--exp-dir', default='logs', type=str,
                        help='directory to save results and logs')
    parser.add_argument(
        '--gpu', type=str, required=True,
        help='GPU number. Multiple GPUs are possible for PT models.')
    parser.add_argument(
        '--p', type=float, default=0.05,
        help='Probability of changing a coordinate. Note: check the paper for the best values. '
             'Linf standard: 0.05, L2 standard: 0.1. But robust models require higher p.')
    parser.add_argument('--epsilon', type=float, help='Radius of the Lp ball.')
    parser.add_argument('--max_queries', type=int, default=10000)
    parser.add_argument(
        '--json-config', type=str,
        default='/home1/machen/meta_perturbations_black_box_attack/configures/square_attack_conf.json',
        help='a configures file to be passed in instead of arguments')
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--targeted', action="store_true")
    parser.add_argument('--target_type', type=str, default='increment',
                        choices=['random', 'least_likely', "increment"])
    parser.add_argument('--attack_defense', action="store_true")
    parser.add_argument('--defense_model', type=str, default=None)
    parser.add_argument('--arch', default=None, type=str,
                        help='network architecture')
    parser.add_argument('--test_archs', action="store_true")
    args = parser.parse_args()

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    if args.json_config:
        # The JSON file is the base; every CLI value that is not None is
        # written over it.  Options with a non-None argparse default therefore
        # always win over the JSON value.
        with open(args.json_config) as config_file:
            defaults = json.load(config_file)[args.dataset][args.norm]
        cli_values = {k: v for k, v in vars(args).items() if v is not None}
        defaults.update(cli_values)
        args = SimpleNamespace(**defaults)

    # The merge above drops None-valued options, so ``arch`` and
    # ``defense_model`` may not exist as attributes at all; validate them via
    # getattr and fail with a usable message instead of an AttributeError or
    # an unbound ``log_file_path`` further down.
    if not args.test_archs and getattr(args, "arch", None) is None:
        parser.error("either --test_archs or --arch must be given")
    if args.attack_defense and getattr(args, "defense_model", None) is None:
        parser.error("--attack_defense requires --defense_model")

    if args.targeted and args.dataset == "ImageNet":
        args.max_queries = 50000
    args.exp_dir = os.path.join(
        args.exp_dir,
        get_exp_dir_name(args.dataset, args.norm, args.targeted,
                         args.target_type, args))
    os.makedirs(args.exp_dir, exist_ok=True)

    if args.test_archs:
        if args.attack_defense:
            log_file_path = osp.join(
                args.exp_dir, 'run_defense_{}.log'.format(args.defense_model))
        else:
            log_file_path = osp.join(args.exp_dir, 'run.log')
    else:
        if args.attack_defense:
            log_file_path = osp.join(
                args.exp_dir,
                'run_defense_{}_{}.log'.format(args.arch, args.defense_model))
        else:
            log_file_path = osp.join(args.exp_dir,
                                     'run_{}.log'.format(args.arch))
    set_log_file(log_file_path)

    if args.test_archs:
        # Keep only the architectures for which a checkpoint is found on disk.
        archs = []
        if args.dataset == "CIFAR-10" or args.dataset == "CIFAR-100":
            for arch in MODELS_TEST_STANDARD[args.dataset]:
                test_model_path = "{}/train_pytorch_model/real_image_model/{}-pretrained/{}/checkpoint.pth.tar".format(
                    PY_ROOT, args.dataset, arch)
                if os.path.exists(test_model_path):
                    archs.append(arch)
                else:
                    log.info(test_model_path + " does not exists!")
        elif args.dataset == "TinyImageNet":
            for arch in MODELS_TEST_STANDARD[args.dataset]:
                test_model_list_path = "{root}/train_pytorch_model/real_image_model/{dataset}@{arch}*.pth.tar".format(
                    root=PY_ROOT, dataset=args.dataset, arch=arch)
                test_model_path = list(glob.glob(test_model_list_path))
                if test_model_path and os.path.exists(test_model_path[0]):
                    archs.append(arch)
        else:
            for arch in MODELS_TEST_STANDARD[args.dataset]:
                test_model_list_path = "{}/train_pytorch_model/real_image_model/{}-pretrained/checkpoints/{}*.pth".format(
                    PY_ROOT, args.dataset, arch)
                # An empty match means this arch does not exist for the dataset.
                if glob.glob(test_model_list_path):
                    archs.append(arch)
    else:
        archs = [args.arch]
    args.arch = ", ".join(archs)

    log.info('Command line is: {}'.format(' '.join(sys.argv)))
    log.info("Log file is written in {}".format(log_file_path))
    log.info('Called with args:')
    print_args(args)

    attacker = MetaSimulatorSquareAttack(args.dataset, args.batch_size,
                                         args.targeted, args.target_type,
                                         args.epsilon, args.norm,
                                         max_queries=args.max_queries)
    for arch in archs:
        if args.attack_defense:
            save_result_path = args.exp_dir + "/{}_{}_result.json".format(
                arch, args.defense_model)
        else:
            save_result_path = args.exp_dir + "/{}_result.json".format(arch)
        if os.path.exists(save_result_path):
            continue
        log.info("Begin attack {} on {}, result will be saved to {}".format(
            arch, args.dataset, save_result_path))
        if args.attack_defense:
            model = DefensiveModel(args.dataset, arch, no_grad=True,
                                   defense_model=args.defense_model)
        else:
            model = StandardModel(args.dataset, arch, no_grad=True)
        model.cuda()
        model.eval()
        attacker.attack_all_images(args, arch, model, save_result_path)
        # Move the finished model off the GPU so the models of successive
        # architectures do not pile up in device memory.
        model.cpu()
def main():
    """Command-line entry point for the VBAD attack experiments.

    Reads the CLI arguments, overrides them with the JSON section selected by
    ``[dataset][targeted|untargeted][norm]``, prepares the experiment
    directory and log file, builds the surrogate feature extractor that feeds
    the tentative-perturbation generator, and attacks every target
    architecture returned by ``get_model_names``.  Architectures whose result
    file already exists are skipped.

    Raises:
        ValueError: if neither ``test_archs`` nor ``arch`` is set, because no
            log file name can be derived in that case.
    """
    args = get_args_parse()
    os.environ[
        "TORCH_HOME"] = "/home1/machen/meta_perturbations_black_box_attack/train_pytorch_model/real_image_model/ImageNet-pretrained"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    target_str = "targeted" if args.targeted else "untargeted"
    with open(args.json_config) as config_file:
        json_conf = json.load(config_file)[args.dataset][target_str][args.norm]
    # Here the JSON values take precedence over the CLI values.
    merged = vars(args)
    merged.update(json_conf)
    args = SimpleNamespace(**merged)

    if not args.test_archs and args.arch is None:
        # Without this check the function would crash later on an unbound
        # ``log_file_path``.
        raise ValueError("either test_archs or arch must be specified")

    if args.targeted and args.dataset == "ImageNet":
        args.max_queries = 50000
    # Derive the experiment sub-directory from the run configuration.
    args.exp_dir = osp.join(
        args.exp_dir,
        get_exp_dir_name(args.dataset, args.surrogate_arch, args.norm,
                         args.targeted, args.target_type, args))
    os.makedirs(args.exp_dir, exist_ok=True)

    if args.test_archs:
        if args.attack_defense:
            log_file_path = osp.join(
                args.exp_dir, 'run_defense_{}.log'.format(args.defense_model))
        else:
            log_file_path = osp.join(args.exp_dir, 'run.log')
    else:
        if args.attack_defense:
            log_file_path = osp.join(
                args.exp_dir,
                'run_defense_{}_{}.log'.format(args.arch, args.defense_model))
        else:
            log_file_path = osp.join(args.exp_dir,
                                     'run_{}.log'.format(args.arch))
    set_log_file(log_file_path)

    archs = get_model_names(args)
    args.arch = ", ".join(archs)
    log.info('Command line is: {}'.format(' '.join(sys.argv)))
    log.info("Log file is written in {}".format(log_file_path))
    log.info('Called with args:')
    print_args(args)

    layer = ['fc']
    # Supported surrogates: (torchvision constructor, feature-extractor class).
    surrogate_builders = {
        "resnet50": (models.resnet50, ResNetFeatureExtractor),
        "resnet101": (models.resnet101, ResNetFeatureExtractor),
        "densenet121": (models.densenet121, DensenetFeatureExtractor),
        "densenet169": (models.densenet169, DensenetFeatureExtractor),
    }
    extractors = []
    if args.surrogate_arch in surrogate_builders:
        build_backbone, extractor_cls = surrogate_builders[args.surrogate_arch]
        backbone = build_backbone(pretrained=True).eval()
        extractors.append(extractor_cls(backbone, layer).eval().cuda())
    else:
        # Behaviour is unchanged (the generator receives an empty list), but
        # the situation is no longer silent.
        # NOTE(review): an empty extractor list is probably unusable downstream
        # - consider raising here once that is confirmed.
        log.info("Unsupported surrogate_arch {}, no feature extractor is "
                 "created".format(args.surrogate_arch))

    directions_generator = TentativePerturbationGenerator(extractors,
                                                          norm=args.norm,
                                                          part_size=32,
                                                          preprocess=True)
    attacker = VBADAttack(args, directions_generator)
    for arch in archs:
        if args.attack_defense:
            save_result_path = args.exp_dir + "/{}_{}_result.json".format(
                arch, args.defense_model)
        else:
            save_result_path = args.exp_dir + "/{}_result.json".format(arch)
        if os.path.exists(save_result_path):
            continue
        if args.attack_defense:
            model = DefensiveModel(args.dataset, arch, no_grad=True,
                                   defense_model=args.defense_model)
        else:
            model = StandardModel(args.dataset, arch, no_grad=True)
        model.cuda()
        model.eval()
        attacker.attack_all_images(args, arch, model, save_result_path)
        # Free the GPU before the next architecture is loaded.
        model.cpu()
def main():
    """Command-line entry point for the Frank-Wolfe black-box attack experiments.

    Overlays the non-None CLI values on the JSON defaults selected by
    ``[dataset][norm]``, prepares the experiment directory and log file,
    builds the attacker, seeds the random number generators, resolves the
    target architectures and attacks each of them.  Architectures whose result
    file already exists are skipped.

    Raises:
        ValueError: if neither ``test_archs`` nor ``arch`` is set, or if
            ``attack_defense`` is requested without a ``defense_model``.
    """
    args = get_parse_args()
    if args.json_config:
        # The JSON file is the base; every CLI value that is not None is
        # written over it.
        with open(args.json_config) as config_file:
            defaults = json.load(config_file)[args.dataset][args.norm]
        cli_values = {k: v for k, v in vars(args).items() if v is not None}
        defaults.update(cli_values)
        args = SimpleNamespace(**defaults)

    # The merge above drops None-valued options, so ``arch`` and
    # ``defense_model`` may be missing as attributes; use getattr and raise a
    # clear error instead of relying on ``assert`` (stripped under -O) or
    # running into an AttributeError / unbound ``log_file_path``.
    if not args.test_archs and getattr(args, "arch", None) is None:
        raise ValueError("either test_archs or arch must be specified")
    if args.attack_defense and getattr(args, "defense_model", None) is None:
        raise ValueError("attack_defense requires defense_model to be set")

    args.exp_dir = os.path.join(
        args.exp_dir,
        get_exp_dir_name(args.dataset, args.norm, args.targeted,
                         args.target_type, args))
    os.makedirs(args.exp_dir, exist_ok=True)

    if args.test_archs:
        if args.attack_defense:
            log_file_path = os.path.join(
                args.exp_dir, 'run_defense_{}.log'.format(args.defense_model))
        else:
            log_file_path = os.path.join(args.exp_dir, 'run.log')
    else:
        if args.attack_defense:
            log_file_path = os.path.join(
                args.exp_dir,
                'run_defense_{}_{}.log'.format(args.arch, args.defense_model))
        else:
            log_file_path = os.path.join(args.exp_dir,
                                         'run_{}.log'.format(args.arch))
    set_log_file(log_file_path)

    attacker = FrankWolfeBlackBoxAttack(args, args.dataset, args.targeted,
                                        args.target_type, args.epsilon,
                                        args.norm, args.sensing, args.grad_est,
                                        args.delta, 0, 1,
                                        max_queries=args.max_queries)
    # NOTE(review): the seeds are set after the attacker has been built, so
    # any randomness inside its constructor is not covered - confirm that this
    # order is intended.
    torch.backends.cudnn.deterministic = True
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.test_archs:
        # Keep only the architectures for which a checkpoint is found on disk.
        archs = []
        if args.dataset == "CIFAR-10" or args.dataset == "CIFAR-100":
            for arch in MODELS_TEST_STANDARD[args.dataset]:
                test_model_path = "{}/train_pytorch_model/real_image_model/{}-pretrained/{}/checkpoint.pth.tar".format(
                    PY_ROOT, args.dataset, arch)
                if os.path.exists(test_model_path):
                    archs.append(arch)
                else:
                    log.info(test_model_path + " does not exists!")
        elif args.dataset == "TinyImageNet":
            for arch in MODELS_TEST_STANDARD[args.dataset]:
                test_model_list_path = "{root}/train_pytorch_model/real_image_model/{dataset}@{arch}*.pth.tar".format(
                    root=PY_ROOT, dataset=args.dataset, arch=arch)
                test_model_path = list(glob.glob(test_model_list_path))
                if test_model_path and os.path.exists(test_model_path[0]):
                    archs.append(arch)
        else:
            for arch in MODELS_TEST_STANDARD[args.dataset]:
                test_model_list_path = "{}/train_pytorch_model/real_image_model/{}-pretrained/checkpoints/{}*.pth".format(
                    PY_ROOT, args.dataset, arch)
                # An empty match means this arch does not exist for the dataset.
                if glob.glob(test_model_list_path):
                    archs.append(arch)
    else:
        archs = [args.arch]
    args.arch = ", ".join(archs)

    log.info('Command line is: {}'.format(' '.join(sys.argv)))
    log.info("Log file is written in {}".format(log_file_path))
    log.info('Called with args:')
    print_args(args)

    for arch in archs:
        if args.attack_defense:
            save_result_path = args.exp_dir + "/{}_{}_result.json".format(
                arch, args.defense_model)
        else:
            save_result_path = args.exp_dir + "/{}_result.json".format(arch)
        if os.path.exists(save_result_path):
            continue
        log.info("Begin attack {} on {}, result will be saved to {}".format(
            arch, args.dataset, save_result_path))
        if args.attack_defense:
            model = DefensiveModel(args.dataset, arch, no_grad=True,
                                   defense_model=args.defense_model)
        else:
            model = StandardModel(args.dataset, arch, no_grad=True)
        model.cuda()
        model.eval()
        attacker.attack_all_images(args, arch, model, save_result_path)
        # Free the GPU before the next architecture is loaded.
        model.cpu()
def attack_all_images(self, args, arch, tmp_dump_path, result_dump_path):
    """Attack every batch of ``self.dataset_loader`` against ``arch`` and dump the statistics.

    For each batch a two-point finite-difference gradient estimate updates a
    prior, which in turn drives a projected image step.  The two probe images
    are scored either by the target model (this costs 2 queries per unfinished
    image) or by ``self.meta_finetuner``, which is fine-tuned on the probes
    that were sent to the target model.  Images that are finished are removed
    from the working batch after every step.

    Args:
        args: namespace of hyper-parameters.  Fields read here:
            ``attack_defense``, ``defense_model``, ``dataset``, ``batch_size``,
            ``meta_seq_len``, ``tiling``, ``tile_size``, ``targeted``,
            ``target_type``, ``norm``, ``data_loss``, ``max_queries``,
            ``exploration``, ``fd_eta``, ``warm_up_steps``,
            ``meta_predict_steps``, ``notdone_threshold``, ``finetune_times``,
            ``online_lr``, ``image_lr`` and ``epsilon``.  ``vars(args)`` is
            also stored in the result file.
        arch: name of the target architecture to build and attack.
        tmp_dump_path: JSON checkpoint written after every batch; when it
            already exists the run resumes after the batches recorded in it.
        result_dump_path: JSON file that receives the final statistics.

    Raises:
        NotImplementedError: if ``args.targeted`` is set and
            ``args.target_type`` is not one of ``random``, ``least_likely``
            or ``increment``.
    """
    # Per-batch statistics; each one is mirrored by a ``self.<key>_all``
    # tensor that is indexed by the position of the image in the dataset.
    stat_keys = ('query', 'correct', 'not_done', 'success', 'success_query',
                 'not_done_loss', 'not_done_prob')
    resumable_keys = ('query_all', 'correct_all', 'not_done_all',
                      'success_all', 'success_query_all')

    def not_done_fraction(flags):
        # Share of the nominal batch size that still has to be attacked.
        flags_np = flags.detach().cpu().numpy().astype(np.int32)
        return len(np.where(flags_np == 1)[0]) / float(args.batch_size)

    def reduce_or_nan(values, reducer):
        # median()/max() cannot be taken over an empty selection (e.g. when no
        # attack succeeded); report NaN in that case, which is also what
        # mean() yields.
        return reducer(values).item() if values.numel() > 0 else float("nan")

    if args.attack_defense:
        model = DefensiveModel(args.dataset, arch, no_grad=True,
                               defense_model=args.defense_model)
    else:
        model = StandardModel(args.dataset, arch, no_grad=True)
    model.cuda()
    model.eval()

    # Resume support: restore the accumulated statistics once, up front.  The
    # checkpoint written at the end of each batch holds exactly the current
    # values, so re-reading it on every iteration would change nothing.
    resume_batch_idx = 0
    if os.path.exists(tmp_dump_path):
        with open(tmp_dump_path, "r") as file_obj:
            json_content = json.load(file_obj)
        resume_batch_idx = int(json_content["batch_idx"])
        for key in resumable_keys:
            if key in json_content:
                setattr(self, key,
                        torch.from_numpy(np.asarray(json_content[key])).float())

    # The working batch shrinks: finished images are removed automatically.
    for data_idx, data_tuple in enumerate(self.dataset_loader):
        if data_idx < resume_batch_idx:
            continue  # already processed by a previous run
        if args.dataset == "ImageNet":
            if model.input_size[-1] >= 299:
                images, true_labels = data_tuple[1], data_tuple[2]
            else:
                images, true_labels = data_tuple[0], data_tuple[2]
        else:
            images, true_labels = data_tuple[0], data_tuple[1]
        if images.size(-1) != model.input_size[-1]:
            images = F.interpolate(images, size=model.input_size[-1],
                                   mode='bilinear', align_corners=True)
        # Dataset-wide positions of the images of this batch.
        selected = torch.arange(
            data_idx * args.batch_size,
            min((data_idx + 1) * args.batch_size, self.total_images))
        # Maps an index into the (shrinking) working batch to the index the
        # image had in the original batch.
        img_idx_to_batch_idx = ImageIdxToOrigBatchIdx(args.batch_size)
        images, true_labels = images.cuda(), true_labels.cuda()
        first_finetune = True
        finetune_queue = FinetuneQueue(args.batch_size, args.meta_seq_len,
                                       img_idx_to_batch_idx)
        prior_size = model.input_size[-1] if not args.tiling else args.tile_size
        assert args.tiling == (args.dataset == "ImageNet")
        if args.tiling:
            upsampler = Upsample(size=(model.input_size[-2],
                                       model.input_size[-1]))
        else:
            upsampler = lambda x: x
        with torch.no_grad():
            logit = model(images)
        pred = logit.argmax(dim=1)
        query = torch.zeros(images.size(0)).cuda()
        correct = pred.eq(true_labels).float()  # shape = (batch_size,)
        not_done = correct.clone()  # shape = (batch_size,)
        if args.targeted:
            if args.target_type == 'random':
                target_labels = torch.randint(
                    low=0, high=CLASS_NUM[args.dataset],
                    size=true_labels.size()).long().cuda()
                invalid_target_index = target_labels.eq(true_labels)
                # Redraw until no target coincides with the true label.
                while invalid_target_index.sum().item() > 0:
                    target_labels[invalid_target_index] = torch.randint(
                        low=0, high=logit.shape[1],
                        size=target_labels[invalid_target_index].shape
                    ).long().cuda()
                    invalid_target_index = target_labels.eq(true_labels)
            elif args.target_type == 'least_likely':
                target_labels = logit.argmin(dim=1)
            elif args.target_type == "increment":
                target_labels = torch.fmod(true_labels + 1,
                                           CLASS_NUM[args.dataset])
            else:
                raise NotImplementedError('Unknown target_type: {}'.format(
                    args.target_type))
        else:
            target_labels = None
        prior = torch.zeros(images.size(0), IN_CHANNELS[args.dataset],
                            prior_size, prior_size).cuda()
        prior_step = self.gd_prior_step if args.norm == 'l2' else self.eg_prior_step
        image_step = self.l2_image_step if args.norm == 'l2' else self.linf_step
        proj_step = self.l2_proj_step if args.norm == 'l2' else self.linf_proj_step
        criterion = self.cw_loss if args.data_loss == "cw" else self.xent_loss
        adv_images = images.clone()

        for step_index in range(1, args.max_queries + 1):
            # Create noise for exploration, estimate the gradient, and take a
            # PGD step.
            dim = prior.nelement() / images.size(0)  # elements per image
            # Exploration noise around the prior.
            exp_noise = args.exploration * torch.randn_like(prior) / (dim ** 0.5)
            exp_noise = exp_noise.cuda()
            # Finite differences: the prior is the running gradient estimate,
            # probed in two opposite directions.
            q1 = upsampler(prior + exp_noise)
            q2 = upsampler(prior - exp_noise)
            # Loss points for the finite difference estimator.
            q1_images = adv_images + args.fd_eta * q1 / self.norm(q1)
            q2_images = adv_images + args.fd_eta * q2 / self.norm(q2)

            remaining = not_done_fraction(not_done)
            predict_by_target_model = False
            if (step_index <= args.warm_up_steps
                    or (step_index - args.warm_up_steps) % args.meta_predict_steps == 0
                    or remaining <= args.notdone_threshold):
                log.info("predict from target model")
                predict_by_target_model = True
                with torch.no_grad():
                    q1_logits = model(q1_images)
                    q2_logits = model(q2_images)
                    # L2-normalize the logits.
                    q1_logits = q1_logits / torch.norm(
                        q1_logits, p=2, dim=-1, keepdim=True)
                    q2_logits = q2_logits / torch.norm(
                        q2_logits, p=2, dim=-1, keepdim=True)
                finetune_queue.append(q1_images.detach(), q2_images.detach(),
                                      q1_logits.detach(), q2_logits.detach())
                if (step_index >= args.warm_up_steps
                        and remaining > args.notdone_threshold):
                    q1_images_seq, q2_images_seq, q1_logits_seq, q2_logits_seq = \
                        finetune_queue.stack_history_track()
                    finetune_times = args.finetune_times if first_finetune \
                        else random.randint(3, 5)  # FIXME
                    self.meta_finetuner.finetune(
                        q1_images_seq, q2_images_seq, q1_logits_seq,
                        q2_logits_seq, finetune_times, first_finetune,
                        img_idx_to_batch_idx)
                    first_finetune = False
            else:
                with torch.no_grad():
                    q1_logits, q2_logits = self.meta_finetuner.predict(
                        q1_images, q2_images, img_idx_to_batch_idx)
                    q1_logits = q1_logits / torch.norm(
                        q1_logits, p=2, dim=-1, keepdim=True)
                    q2_logits = q2_logits / torch.norm(
                        q2_logits, p=2, dim=-1, keepdim=True)

            l1 = criterion(q1_logits, true_labels, target_labels)
            l2 = criterion(q2_logits, true_labels, target_labels)
            # Finite differences estimate of the directional derivative.
            est_deriv = (l1 - l2) / (args.fd_eta * args.exploration)
            # 2-query gradient estimate, shape (B, C, H, W).
            est_grad = est_deriv.view(-1, 1, 1, 1) * exp_noise
            # Update the prior (not the image) with the estimated gradient.
            prior = prior_step(prior, est_grad, args.online_lr)
            grad = upsampler(prior)
            # Update the image; ``correct`` has been shrunk together with the
            # batch, so it still lines up with ``grad``.
            adv_images = image_step(adv_images,
                                    grad * correct.view(-1, 1, 1, 1),
                                    args.image_lr)
            adv_images = proj_step(images, args.epsilon, adv_images)
            adv_images = torch.clamp(adv_images, 0, 1)
            with torch.no_grad():
                adv_logit = model(adv_images)
            adv_pred = adv_logit.argmax(dim=1)
            adv_prob = F.softmax(adv_logit, dim=1)
            adv_loss = criterion(adv_logit, true_labels, target_labels)
            # Queries are only charged when the target model scored the probes.
            if predict_by_target_model:
                query = query + 2 * not_done
            if args.targeted:
                not_done = not_done * (
                    1 - adv_pred.eq(target_labels).float()).float()
            else:
                # Still predicted as the true label => not finished yet.
                not_done = not_done * adv_pred.eq(true_labels).float()
            success = (1 - not_done) * correct
            success_query = success * query
            not_done_loss = adv_loss * not_done
            not_done_prob = adv_prob[torch.arange(adv_images.size(0)),
                                     true_labels] * not_done

            log.info('Attacking image {} - {} / {}, step {}'.format(
                data_idx * args.batch_size, (data_idx + 1) * args.batch_size,
                self.total_images, step_index))
            log.info(' not_done: {:.4f}'.format(not_done_fraction(not_done)))
            log.info(' fd_scalar: {:.9f}'.format((l1 - l2).mean().item()))
            if success.sum().item() > 0:
                log.info(' mean_query: {:.4f}'.format(
                    success_query[success.byte()].mean().item()))
                log.info(' median_query: {:.4f}'.format(
                    success_query[success.byte()].median().item()))
            if not_done.sum().item() > 0:
                log.info(' not_done_loss: {:.4f}'.format(
                    not_done_loss[not_done.byte()].mean().item()))
                log.info(' not_done_prob: {:.4f}'.format(
                    not_done_prob[not_done.byte()].mean().item()))

            not_done_np = not_done.detach().cpu().numpy().astype(np.int32)
            done_img_idx_list = np.where(not_done_np == 0)[0].tolist()
            delete_all = False
            if done_img_idx_list:
                # First report the finished images to the dataset-wide
                # statistics, then remove them from the working batch.
                batch_stats = {
                    'query': query, 'correct': correct, 'not_done': not_done,
                    'success': success, 'success_query': success_query,
                    'not_done_loss': not_done_loss,
                    'not_done_prob': not_done_prob,
                }
                for skip_index in done_img_idx_list:
                    batch_idx = img_idx_to_batch_idx[skip_index]
                    pos = selected[batch_idx].item()
                    for key in stat_keys:
                        getattr(self, key + "_all")[pos] = \
                            batch_stats[key][skip_index].item()
                images, adv_images, prior, query, true_labels, target_labels, correct, not_done = \
                    self.delete_tensor_by_index_list(
                        done_img_idx_list, images, adv_images, prior, query,
                        true_labels, target_labels, correct, not_done)
                img_idx_to_batch_idx.del_by_index_list(done_img_idx_list)
                delete_all = images is None
                if not delete_all:
                    # The derived statistics are not passed through
                    # delete_tensor_by_index_list, so shrink them here as
                    # well.  Otherwise the final report after the last step
                    # would index pre-deletion tensors with post-deletion
                    # indices.
                    # Assumes the surviving rows keep their original order in
                    # delete_tensor_by_index_list - TODO confirm.
                    kept = torch.from_numpy(
                        np.where(not_done_np != 0)[0]).long().to(success.device)
                    success = success[kept]
                    success_query = success_query[kept]
                    not_done_loss = not_done_loss[kept]
                    not_done_prob = not_done_prob[kept]
            if delete_all:
                break

        # Report the images that are still unfinished to the dataset-wide
        # statistics.
        batch_stats = {
            'query': query, 'correct': correct, 'not_done': not_done,
            'success': success, 'success_query': success_query,
            'not_done_loss': not_done_loss, 'not_done_prob': not_done_prob,
        }
        for key in stat_keys:
            value_all = getattr(self, key + "_all")
            for img_idx, batch_idx in img_idx_to_batch_idx.proj_dict.items():
                pos = selected[batch_idx].item()
                value_all[pos] = batch_stats[key][img_idx].item()
        img_idx_to_batch_idx.proj_dict.clear()

        # Checkpoint after every batch so that the run can be resumed.
        tmp_info_dict = {
            "batch_idx": data_idx + 1,
            "batch_size": args.batch_size
        }
        for key in resumable_keys:
            tmp_info_dict[key] = getattr(self, key).detach().cpu().numpy().tolist()
        with open(tmp_dump_path, "w") as result_file_obj:
            json.dump(tmp_info_dict, result_file_obj, sort_keys=True)

    log.info('Saving results to {}'.format(result_dump_path))
    success_queries = self.success_query_all[self.success_all.byte()]
    not_done_mask = self.not_done_all.byte()
    meta_info_dict = {
        "avg_correct": self.correct_all.mean().item(),
        "avg_not_done": reduce_or_nan(
            self.not_done_all[self.correct_all.byte()], torch.mean),
        "mean_query": reduce_or_nan(success_queries, torch.mean),
        "median_query": reduce_or_nan(success_queries, torch.median),
        "max_query": reduce_or_nan(success_queries, torch.max),
        "correct_all": self.correct_all.detach().cpu().numpy().astype(np.int32).tolist(),
        "not_done_all": self.not_done_all.detach().cpu().numpy().astype(np.int32).tolist(),
        "query_all": self.query_all.detach().cpu().numpy().astype(np.int32).tolist(),
        "not_done_loss": reduce_or_nan(
            self.not_done_loss_all[not_done_mask], torch.mean),
        "not_done_prob": reduce_or_nan(
            self.not_done_prob_all[not_done_mask], torch.mean),
        "args": vars(args)
    }
    with open(result_dump_path, "w") as result_file_obj:
        json.dump(meta_info_dict, result_file_obj, sort_keys=True)
    log.info("done, write stats info to {}".format(result_dump_path))

    # Reset the accumulators so that the next architecture starts from zero.
    self.query_all.fill_(0)
    self.correct_all.fill_(0)
    self.not_done_all.fill_(0)
    self.success_all.fill_(0)
    self.success_query_all.fill_(0)
    self.not_done_loss_all.fill_(0)
    self.not_done_prob_all.fill_(0)
    model.cpu()