def _autoencoder_mse(encoder, decoder, images):
    """Return the mean squared reconstruction error of ``decoder(encoder(images))``.

    Args:
        encoder: object exposing ``predict`` (a Keras model in this script).
        decoder: object exposing ``predict``; applied to the encoder output.
        images: batch fed to ``encoder.predict``.

    Returns:
        Mean over all elements of ``(reconstruction - images) ** 2``.
    """
    reconstructed = decoder.predict(encoder.predict(images))
    return np.mean((reconstructed - images).reshape(-1) ** 2)


def _perturbation_distance(adv, orig, metric, axis=None):
    """Return the distance between adversarial and original images.

    Args:
        adv: adversarial image(s).
        orig: original image(s), broadcastable against ``adv``.
        metric: 'l2' (Euclidean norm) or 'li' (L-infinity, max absolute diff).
        axis: axes to reduce over. ``None`` reduces over every element (one
            scalar distance); ``(1, 2, 3)`` yields one distance per image of
            an NHWC batch.

    Raises:
        ValueError: if ``metric`` is neither 'l2' nor 'li'.
    """
    diff = adv - orig
    if metric == 'l2':
        return np.sum(diff ** 2, axis=axis) ** .5
    if metric == 'li':
        return np.amax(np.abs(diff), axis=axis)
    raise ValueError("unsupported dist_metric: {!r}".format(metric))


def _load_target_model(sess, args, x, y, x_test, y_test, load_robust):
    """Load the MNIST target model and print its test accuracy.

    Args:
        sess: TF session the models are built in.
        args: run options; ``args["robust_type"]`` is read when ``load_robust``.
        x, y: input / label placeholders for the normal MNIST model.
        x_test, y_test: test images and one-hot labels used for the accuracy report.
        load_robust: False loads the normal 'mnist' model; True loads the
            robust model selected by ``args["robust_type"]``.

    Returns:
        ``(target_model_name, target_model)``.

    Raises:
        NotImplementedError: for an unknown ``args["robust_type"]``.
    """
    if not load_robust:
        target_model_name = 'mnist'
        target_model = mnist_models(sess, 0, use_softmax=True, x=x, y=y,
                                    load_existing=True, model_name=target_model_name)
        accuracy = target_model.calcu_acc(x_test, y_test, batch_size=1000)
        print('Test accuracy of target model {}: {:.4f}'.format(target_model_name, accuracy))
        return target_model_name, target_model

    robust_type = args["robust_type"]
    if robust_type == "madry":
        target_model_name = 'madry_robust'
        model_dir = "MNIST_models/madry_robust_model"
        target_model = MadryModel(sess, model_dir=model_dir, bias=0.5)
    elif robust_type == "zico":
        target_model_name = 'zico_robust'
        model_dir = "MNIST_models/zico_robust_model/mnist.pth"
        target_model = ZicoModel(model_dir=model_dir, bias=0.5)
    elif robust_type == "percy":
        target_model_name = 'percy_robust'
        model_dir = "MNIST_models/percy_robust_model/sdp_train/"
        target_model = PercyModel(sess, model_dir=model_dir, bias=0.5)
    else:
        raise NotImplementedError
    pred_class = target_model.pred_class(x_test)
    print('Test accuracy of target robust model :{:.4f}'.format(
        np.sum(pred_class == np.argmax(y_test, axis=1)) / len(x_test)))
    return target_model_name, target_model


def main(args):
    """Run the MNIST hybrid attack: local transfer AEs + query-based black-box attack.

    Pipeline, as implemented below:
      1. Load the target model (normal, or a robust Madry/Zico/Percy model).
      2. If ``args["with_local"]``: load pretrained local models, run a local
         PGD attack on every seed and, when its transfer rate exceeds
         ``args["use_loc_adv_thres"]``, use the local AEs as starting points.
      3. Repeatedly pick the next seed (``select_next_seed`` on the loss from
         ``compute_cw_loss``), run AutoZOOM or NES on it, and record the
         queries spent and whether the attack succeeded.
      4. With local models, every ``args["fine_tune_freq"]`` rounds fine-tune
         them on the black-box attack's ``x_s`` labelled by the target model,
         and refresh the starting points of the remaining seeds.

    Args:
        args: dict of run options. Keys read: model_type, robust_type,
            attack_method, num_img, cost_threshold, with_local, no_tune_local,
            attack_type, nb_epochs_sub, use_loc_adv_thres, fine_tune_freq,
            local_path, load_imgs, seed, no_save_model, train_batch_size,
            loss_function, load_local_AEs, force_tune_baseline,
            test_trans_rate_only, save_path, dist_metric, sort_metric,
            by_class, train_inst_lim, no_save_text. For AutoZOOM,
            ``args["img_resize"]`` is overwritten with 14.

    Side effects:
        Creates directories under ``args["local_path"]`` / ``args["save_path"]``,
        saves local-AE artifacts and the per-seed query counts / success flags,
        and may end the process with ``sys.exit(0)`` (transfer-rate-only mode,
        or a failed distance sanity check).

    Raises:
        RuntimeError: if Keras is not using the TensorFlow backend.
        ValueError: if ``args["dist_metric"]`` is neither 'l2' nor 'li'.
    """
    load_robust = args["model_type"] != "normal"
    # RNG seeds are set after the attack inputs are generated (see below).
    data = MNIST()
    if not hasattr(K, "tf"):
        raise RuntimeError("This tutorial requires keras to be configured"
                           " to use the TensorFlow backend.")

    if keras.backend.image_dim_ordering() != 'tf':
        keras.backend.set_image_dim_ordering('tf')
        print("INFO: '~/.keras/keras.json' sets 'image_dim_ordering' to "
              "'th', temporarily setting to 'tf'")

    # Create TF session and set as Keras backend session
    sess = tf.Session()
    keras.backend.set_session(sess)

    # MNIST test data (reuse the dataset loaded above instead of reloading it)
    x_test, y_test = data.test_data, data.test_labels

    all_trans_rate_ls = []  # transfer rates measured on the independent test set
    remain_trans_rate_ls = []  # transfer rates of remaining seeds, only used with local fine-tuning

    # Define input TF placeholders
    class_num = 10
    image_size = 28
    num_channels = 1
    x = tf.placeholder(tf.float32, shape=(None, image_size, image_size, num_channels))
    y = tf.placeholder(tf.float32, shape=(None, class_num))

    target_model_name, target_model = _load_target_model(
        sess, args, x, y, x_test, y_test, load_robust)

    if args["attack_method"] == "autozoom":
        # The CODEC wrapper is not constructed for MNIST; AutoZOOM is passed 0
        # and works with the pretrained MNIST autoencoder below instead.
        codec = 0
        args["img_resize"] = 14
        codec_dir = 'MNIST_models/mnist_autoencoder/'
        encoder = load_model(codec_dir + 'whole_mnist_encoder.h5')
        decoder = load_model(codec_dir + 'whole_mnist_decoder.h5')
        # sanity-check the autoencoder reconstruction on two test images
        for img_idx in (100, 0):
            diff_mse = _autoencoder_mse(encoder, decoder, data.test_data[img_idx:img_idx + 1])
            print("[Info][AE] MSE:{:.4f}".format(diff_mse))
        # define black-box model graph
        blackbox_attack = AutoZOOM(sess, target_model, args, decoder, codec,
                                   num_channels, image_size, class_num)

    # local model architectures, indexed by model type
    local_model_names = ['modelA', 'modelB', 'modelC', 'modelD', 'modelE', 'modelF']

    nb_imgs = args["num_img"]
    # local attack related params
    clip_min = -0.5
    clip_max = 0.5
    li_eps = args["cost_threshold"]
    # distances above the budget plus 10% slack are treated as a setup error
    dist_limit = li_eps + li_eps / 10

    load_existing = True  # start from pretrained local models (False: random models)
    with_local = args["with_local"]  # True: hybrid attack; False: baseline attack only
    stop_fine_tune_flag = bool(args["no_tune_local"])

    # tag used in the names of the saved result files
    if not with_local:
        loc_adv = 'orig'
    elif args["no_tune_local"]:
        loc_adv = 'adv_no_tune'
    else:
        loc_adv = 'adv_with_tune'

    is_targeted = args["attack_type"] == "targeted"

    # local model training parameters
    sub_epochs = args["nb_epochs_sub"]  # epochs for local model training / fine-tuning
    # only use local AEs as starting points when their transfer rate exceeds this
    use_loc_adv_thres = args["use_loc_adv_thres"]
    # pretrained local models are trusted, so hybrid attack starts out using local AEs
    use_loc_adv_flag = True
    fine_tune_freq = args["fine_tune_freq"]  # fine-tune local models every K attacked seeds

    # store the attack input files (e.g., original image, target class)
    input_file_prefix = os.path.join(args["local_path"], target_model_name, args["attack_type"])
    os.makedirs(input_file_prefix, exist_ok=True)
    # generate the attack seeds and target classes
    target_ys_one_hot, orig_images, _, orig_labels, _, x_trans_inputs = \
        generate_attack_inputs(sess, target_model, x_test, y_test, class_num, nb_imgs,
                               load_imgs=args["load_imgs"], load_robust=load_robust,
                               file_path=input_file_prefix)
    # labels the attack is measured against: target class if targeted, else true class
    attack_labels = target_ys_one_hot if is_targeted else orig_labels

    # Attack inputs are generated with their own fixed seed (1234 per the original
    # note); re-seed every RNG with this run's seed to improve reproducibility.
    random.seed(args["seed"])
    np.random.seed(args["seed"])
    tf.set_random_seed(args["seed"])

    start_points = np.copy(orig_images)  # starting points for the gradient attacks, can be local AEs
    # attack statistical info
    dist_record = np.zeros(len(orig_labels), dtype=float)
    query_num_vec = np.zeros(len(orig_labels), dtype=int)
    success_vec = np.zeros(len(orig_labels), dtype=bool)
    adv_classes = np.zeros(len(orig_labels), dtype=int)
    attacked_flag = np.zeros(len(orig_labels), dtype=bool)

    model_types = [0, 1, 2, 5]  # architectures of the local models
    local_model_ls = []
    load_model_names = []
    save_dir = 'MNIST_models/normal_models/'
    callbacks_ls = []  # per-local-model callbacks (empty when no_save_model)

    def run_local_attack(images, labels):
        """Run the local PGD attack on ``images`` and evaluate transfer on the target model."""
        return local_attack_in_batches(sess, images, labels, eval_batch_size=500,
                                       attack_graph=local_attack_graph, model=target_model,
                                       clip_min=clip_min, clip_max=clip_max,
                                       load_robust=load_robust)

    def cw_loss(images):
        """Return ``(loss, free_idx)`` of ``images`` w.r.t. ``attack_labels`` on the target model."""
        return compute_cw_loss(sess, target_model, images, attack_labels,
                               targeted=is_targeted, load_robust=load_robust)

    ################## load local models #####################
    if with_local:
        for model_type in model_types:
            model_name = local_model_names[model_type]
            load_model_names.append(model_name)
            loc_model = mnist_models(sess, model_type, use_softmax=True, x=x, y=y,
                                     load_existing=load_existing, model_name=model_name)
            local_model_ls.append(loc_model)
            opt = keras.optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
            loc_model.model.compile(loss='categorical_crossentropy',
                                    optimizer=opt,
                                    metrics=['accuracy'])

            callbacks = None
            if not args["no_save_model"]:
                suffix = '_pretrained.h5' if load_existing else '.h5'
                # checkpoint the local model whenever validation accuracy improves
                callbacks = [ModelCheckpoint(filepath=save_dir + model_name + suffix,
                                             monitor='val_acc',
                                             verbose=0,
                                             save_best_only=True)]
                callbacks_ls.append(callbacks)
            if not load_existing:
                # train from scratch on the seeds (not reached while load_existing is True)
                print("Train on %d data and validate on %d data" % (len(orig_labels), len(y_test)))
                loc_model.model.fit(orig_images, orig_labels,
                                    batch_size=args["train_batch_size"],
                                    epochs=sub_epochs,
                                    verbose=0,
                                    validation_data=(x_test, y_test),
                                    shuffle=True,
                                    callbacks=callbacks)
            scores = loc_model.model.evaluate(x_test, y_test, verbose=0)
            print('Test accuracy of model {}: {:.4f}'.format(model_name, scores[1]))
        ####################### end of load local models ########################

        # Define attack graph of the PGD attack against the ensemble of local models
        local_attack_graph = LinfPGDAttack(local_model_ls,
                                           epsilon=li_eps,
                                           k=100,
                                           a=0.01,
                                           random_start=True,
                                           loss_func=args["loss_function"],
                                           targeted=is_targeted,
                                           x=x,
                                           y=y)

        # local info is stored per seed so results can be averaged over seeds
        local_info_file_prefix = os.path.join(args["local_path"], target_model_name,
                                              args["attack_type"], str(args["seed"]))
        os.makedirs(local_info_file_prefix, exist_ok=True)
        if not args["load_local_AEs"]:
            # transfer check: craft local AEs and test them on the target model
            remaining = np.logical_not(attacked_flag)
            (all_trans_rate, _, local_aes, pgd_cnt_mat,
             _, _, _, _, _, ave_gap) = run_local_attack(start_points[remaining],
                                                        attack_labels[remaining])
            # losses of local AEs and of original images, used for scheduling experiments
            adv_img_loss, _ = cw_loss(local_aes)
            orig_img_loss, _ = cw_loss(orig_images)

            if not args["force_tune_baseline"]:
                # save local AEs and per-seed info used for scheduling seeds in batch attacks
                np.save(os.path.join(local_info_file_prefix, 'local_aes.npy'), local_aes)
                np.savetxt(os.path.join(local_info_file_prefix, 'pgd_cnt_mat.txt'), pgd_cnt_mat)
                np.savetxt(os.path.join(local_info_file_prefix, 'orig_img_loss.txt'), orig_img_loss)
                np.savetxt(os.path.join(local_info_file_prefix, 'adv_img_loss.txt'), adv_img_loss)
                np.savetxt(os.path.join(local_info_file_prefix, 'ave_gap.txt'), ave_gap)
        else:
            local_aes = np.load(os.path.join(local_info_file_prefix, 'local_aes.npy'))
            pred_labs = np.argmax(target_model.predict_prob(np.array(local_aes)), axis=1)
            print('correct number', np.sum(pred_labs == np.argmax(attack_labels, axis=1)))
            all_trans_rate = accuracy_score(np.argmax(attack_labels, axis=1), pred_labs)
        # for untargeted attacks the helpers report the fraction still classified correctly
        if not is_targeted:
            all_trans_rate = 1 - all_trans_rate
        print('** Transfer Rate: **' + str(all_trans_rate))

        # independent test set for transferability: experiment only, not counted as queries
        ind_all_trans_rate = run_local_attack(x_trans_inputs, attack_labels)[0]

        # record the queries spent by querying the local samples
        query_num_vec[np.logical_not(attacked_flag)] += 1
        if not is_targeted:
            ind_all_trans_rate = 1 - ind_all_trans_rate
        print('** (Independent Set) Transfer Rate: **' + str(ind_all_trans_rate))
        all_trans_rate_ls.append(ind_all_trans_rate)
        if args["test_trans_rate_only"]:
            print("Program terminates after checking the transfer rate!")
            sys.exit(0)
        if not args["force_tune_baseline"] and all_trans_rate > use_loc_adv_thres:
            print("Updated the starting points to local AEs....")
            start_points[np.logical_not(attacked_flag)] = local_aes
            use_loc_adv_flag = True

    # initial fine-tuning set: starting points labelled by querying the target model
    S = np.copy(start_points)
    S_label = target_model.predict_prob(S)
    S_label_cate = np_utils.to_categorical(np.argmax(S_label, axis=1), class_num)

    # order in which images are attacked (image indices)
    candi_idx_ls = []
    pre_free_idx = []
    # state for select_next_seed's per-class balancing (meant to diversify the
    # fine-tuning set; reported as not effective)
    per_cls_cnt = 0
    cls_order = 0
    change_limit = False
    max_lim_num = int(fine_tune_freq / class_num)

    # store gradient black-box attack results
    out_dir_prefix = os.path.join(args["save_path"], args["attack_method"], target_model_name,
                                  args["attack_type"], str(args["seed"]))
    os.makedirs(out_dir_prefix, exist_ok=True)

    ######### main loop of hybrid attack ###########
    for itr in range(len(orig_labels)):
        print("#------------ Attack Round {} ----------------#".format(itr))
        img_loss, free_idx = cw_loss(start_points)

        free_idx_diff = list(set(free_idx) - set(pre_free_idx))
        print("new free_idx found:", free_idx_diff)
        candi_idx_ls.extend(free_idx_diff)
        pre_free_idx = free_idx

        if with_local and len(free_idx) > 0:
            # "free" attacks: starting points that already fool the target model
            attacked_flag[free_idx] = 1
            success_vec[free_idx] = 1
            dist = _perturbation_distance(start_points[free_idx], orig_images[free_idx],
                                          args['dist_metric'], axis=(1, 2, 3))
            adv_classes[free_idx] = target_model.pred_class(start_points[free_idx])
            dist_record[free_idx] = dist
            if np.amax(dist) >= dist_limit:
                print("there are some problems in setting the perturbation distance!")
                sys.exit(0)
        print("Number of Unattacked Seeds: ", np.sum(np.logical_not(attacked_flag)))
        if attacked_flag.all():
            # early stop when all seeds are successfully attacked
            break

        # push attacked seeds to the end of the selection order
        if args["sort_metric"] == "min":
            img_loss[attacked_flag] = 1e10
        elif args["sort_metric"] == "max":
            img_loss[attacked_flag] = -1e10
        candi_idx, per_cls_cnt, cls_order, change_limit, max_lim_num = select_next_seed(
            img_loss, attacked_flag, args["sort_metric"], args["by_class"], fine_tune_freq,
            class_num, per_cls_cnt, cls_order, change_limit, max_lim_num)

        print(candi_idx)
        candi_idx_ls.append(candi_idx)
        input_img = start_points[candi_idx:candi_idx + 1]
        orig_img = orig_images[candi_idx:candi_idx + 1]

        if args["attack_method"] == "autozoom":
            # encoder performance check
            diff_mse = _autoencoder_mse(encoder, decoder, input_img)
        else:
            diff_mse = 0.0
        print("[Info][Start]: test_index:{}, true label:{}, target label:{}, MSE:{}".format(
            candi_idx, np.argmax(orig_labels[candi_idx]),
            np.argmax(target_ys_one_hot[candi_idx]), diff_mse))
        ################## BEGIN: bbox attacks ############################
        if args["attack_method"] == "autozoom":
            x_s, ae, query_num = autozoom_attack(blackbox_attack, input_img, orig_img,
                                                 attack_labels[candi_idx])
        else:
            x_s, query_num, ae = nes_attack(args, target_model, input_img, orig_img,
                                            np.argmax(attack_labels[candi_idx]),
                                            lower=clip_min, upper=clip_max)
            x_s = np.squeeze(np.array(x_s), axis=1)
        ################## END: bbox attacks ############################

        attacked_flag[candi_idx] = 1
        # fill the query info, etc
        if len(ae.shape) == 3:
            ae = np.expand_dims(ae, axis=0)
        dist = _perturbation_distance(ae, orig_images[candi_idx], args['dist_metric'])
        adv_class = target_model.pred_class(ae)
        adv_classes[candi_idx] = adv_class
        query_num_vec[candi_idx] += query_num

        if dist >= dist_limit or math.isnan(dist):
            print("the distance is not optimized properly")
            sys.exit(0)
        dist_record[candi_idx] = dist

        if is_targeted:
            if adv_class == np.argmax(target_ys_one_hot[candi_idx]):
                success_vec[candi_idx] = 1
        else:
            if adv_class != np.argmax(orig_labels[candi_idx]):
                success_vec[candi_idx] = 1

        if attacked_flag.all():
            print("Early termination because all seeds are successfully attacked!")
            break
        ##############################################################
        ## Starts the section of substitute training and local advs ##
        ##############################################################
        if with_local:
            if not stop_fine_tune_flag:
                # augment the local training data with target-model labels of x_s
                x_s_arr = np.array(x_s)
                S = np.concatenate((S, x_s_arr), axis=0)
                S_label_add = target_model.predict_prob(x_s_arr)
                S_label_add_cate = np_utils.to_categorical(np.argmax(S_label_add, axis=1), class_num)
                S_label_cate = np.concatenate((S_label_cate, S_label_add_cate), axis=0)
                S_label = np.concatenate((S_label, S_label_add), axis=0)
                # fine-tune the local models every fine_tune_freq rounds
                if itr % fine_tune_freq == 0 and itr != 0:
                    if len(S_label) > args["train_inst_lim"]:
                        # cap the training set size by random subsampling
                        curr_len = len(S_label)
                        rand_idx = np.random.choice(len(S_label), args["train_inst_lim"], replace=False)
                        S = S[rand_idx]
                        S_label = S_label[rand_idx]
                        S_label_cate = S_label_cate[rand_idx]
                        print("current num: %d, max train instance limit %d is reached, performed random sampling to get %d samples!" % (curr_len, len(S_label), len(rand_idx)))

                    print("Train on %d data and validate on %d data" % (len(S_label), len(y_test)))
                    for model_idx, (model_name, loc_model) in enumerate(zip(load_model_names, local_model_ls)):
                        callbacks = None if args["no_save_model"] else callbacks_ls[model_idx]
                        loc_model.model.fit(S, S_label,
                                            batch_size=args["train_batch_size"],
                                            epochs=sub_epochs,
                                            verbose=0,
                                            validation_data=(x_test, y_test),
                                            shuffle=True,
                                            callbacks=callbacks)
                        scores = loc_model.model.evaluate(x_test, y_test, verbose=0)
                        print('Test accuracy of model {}: {:.4f}'.format(model_name, scores[1]))

                    if not attacked_flag.all():
                        # re-run the local attack on the seeds that are not attacked yet
                        remaining = np.logical_not(attacked_flag)
                        remain_trans_rate, _, remain_local_aes = run_local_attack(
                            orig_images[remaining], attack_labels[remaining])[:3]
                        if not is_targeted:
                            remain_trans_rate = 1 - remain_trans_rate
                        print('<<Ramaining Seed Transfer Rate>>:**' + str(remain_trans_rate))
                        if remain_trans_rate <= 0 and use_loc_adv_flag:
                            print("No improvement from substitue training, stop fine-tuning!")
                            stop_fine_tune_flag = True

                        # transfer rate check with independent test examples
                        all_trans_rate = run_local_attack(x_trans_inputs, attack_labels)[0]
                        if not is_targeted:
                            all_trans_rate = 1 - all_trans_rate
                        print('<<Overall Transfer Rate>>: **' + str(all_trans_rate))

                        # start from local AEs only once their transfer rate is high enough
                        if not args["force_tune_baseline"]:
                            if not use_loc_adv_flag and remain_trans_rate > use_loc_adv_thres:
                                use_loc_adv_flag = True
                            if use_loc_adv_flag:
                                print("Updated the starting points....")
                                start_points[remaining] = remain_local_aes
                            # one query per remaining seed was spent checking the new local AEs
                            query_num_vec[remaining] += 1
                        remain_trans_rate_ls.append(remain_trans_rate)
                        all_trans_rate_ls.append(all_trans_rate)
            np.set_printoptions(precision=4)
            print("all_trans_rate:")
            print(all_trans_rate_ls)
            print("remain_trans_rate")
            print(remain_trans_rate_ls)

    # save the query information of all seeds
    if not args["no_save_text"]:
        save_name_file = os.path.join(out_dir_prefix, "{}_num_queries.txt".format(loc_adv))
        np.savetxt(save_name_file, query_num_vec, fmt='%d', delimiter=' ')
        save_name_file = os.path.join(out_dir_prefix, "{}_success_flags.txt".format(loc_adv))
        np.savetxt(save_name_file, success_vec, fmt='%d', delimiter=' ')
def main(args):
    if args["model_type"] == "normal":
        load_robust = False
    else:
        load_robust = True
    simple_target_model = args[
        "simple_target_model"]  # if true, target model is simple CIAR10 model (LeNet)
    simple_local_model = True  # if true, local models are simple CIFAR10 models (LeNet)

    # Set TF random seed to improve reproducibility
    tf.set_random_seed(args["seed"])
    data = CIFAR()
    if not hasattr(K, "tf"):
        raise RuntimeError("This tutorial requires keras to be configured"
                           " to use the TensorFlow backend.")

    if keras.backend.image_dim_ordering() != 'tf':
        keras.backend.set_image_dim_ordering('tf')
        print("INFO: '~/.keras/keras.json' sets 'image_dim_ordering' to "
              "'th', temporarily setting to 'tf'")

    # Create TF session and set as Keras backend session
    sess = tf.Session()
    keras.backend.set_session(sess)

    x_test, y_test = CIFAR().test_data, CIFAR().test_labels

    all_trans_rate_ls = []  # store transfer rate of all seeds
    remain_trans_rate_ls = [
    ]  # store transfer rate of remaining seeds, used only in local model fine-tuning

    # Define input TF placeholders
    class_num = 10
    image_size = 32
    num_channels = 3
    test_batch_size = 100
    x = tf.placeholder(tf.float32,
                       shape=(None, image_size, image_size, num_channels))
    y = tf.placeholder(tf.float32, shape=(None, class_num))
    # required by the local robust densenet model
    is_training = tf.placeholder(tf.bool, shape=[])
    keep_prob = tf.placeholder(tf.float32)
    ########################### load the target model ##########################################
    if not load_robust:
        if simple_target_model:
            target_model_name = 'modelA'
            target_model = cifar10_models_simple(sess,test_batch_size, 0, use_softmax=True,x = x, y = y,\
            load_existing=True,model_name=target_model_name)
        else:
            target_model_name = 'densenet'
            target_model = cifar10_models(sess,0,test_batch_size = test_batch_size,use_softmax=True,x = x, y = y,\
            load_existing=True,model_name=target_model_name)
        accuracy = target_model.calcu_acc(x_test, y_test)
        print('Test accuracy of target model {}: {:.4f}'.format(
            target_model_name, accuracy))
    else:
        if args["robust_type"] == "madry":
            target_model_name = 'madry_robust'
            model_dir = "CIFAR10_models/Robust_Deep_models/Madry_robust_target_model"  # TODO: pur your own madry robust target model directory here
            target_model = Load_Madry_Model(sess,
                                            model_dir,
                                            bias=0.5,
                                            scale=255)
        elif args["robust_type"] == "zico":
            # Note: add zico cifar10 model will added in future
            target_model_name = 'zico_robust'
            model_dir = ""  # TODO: put your own robust zico target model directory here
            target_model = Load_Zico_Model(model_dir=model_dir,
                                           bias=0.5,
                                           scale=255)
        else:
            raise NotImplementedError
        corr_preds = target_model.correct_prediction(x_test,
                                                     np.argmax(y_test, axis=1))
        print('Test accuracy of target robust model :{:.4f}'.format(
            np.sum(corr_preds) / len(x_test)))
    ##################################### end of load target model ###################################
    local_model_names = args["local_model_names"]
    robust_indx = []
    normal_local_types = []
    for loc_model_name in local_model_names:
        if loc_model_name == "adv_densenet" or loc_model_name == "adv_vgg" or loc_model_name == "adv_resnet":
            # normal_local_types.append(0)
            robust_indx.append(1)
        else:
            robust_indx.append(0)
            if loc_model_name == "modelB":
                normal_local_types.append(1)
            elif loc_model_name == "modelD":
                normal_local_types.append(3)
            elif loc_model_name == "modelE":
                normal_local_types.append(4)
    print("robust index: ", robust_indx)
    print("normal model types:", normal_local_types)

    local_model_folder = ''
    for ii in range(len(local_model_names)):
        if ii != len(local_model_names) - 1:
            local_model_folder += local_model_names[ii] + '_'
        else:
            local_model_folder += local_model_names[ii]

    nb_imgs = args["num_img"]
    # local model attack related params
    clip_min = -0.5
    clip_max = 0.5
    li_eps = args["cost_threshold"]
    alpha = 1.0
    k = 100
    a = 0.01

    load_existing = True  # load pretrained local models, if false, random model will be given
    with_local = args[
        "with_local"]  # if true, hybrid attack, otherwise, only baseline attacks
    if args["no_tune_local"]:
        stop_fine_tune_flag = True
        load_existing = True
    else:
        stop_fine_tune_flag = False

    if with_local:
        if load_existing:
            loc_adv = 'adv_with_tune'
        if args["no_tune_local"]:
            loc_adv = 'adv_no_tune'
    else:
        loc_adv = 'orig'

    # target type
    if args["attack_type"] == "targeted":
        is_targeted = True
    else:
        is_targeted = False

    sub_epochs = args["nb_epochs_sub"]  # epcohs for local model training
    use_loc_adv_thres = args[
        "use_loc_adv_thres"]  # threshold for transfer attack success rate, it is used when we need to start from local adversarial seeds
    use_loc_adv_flag = True  # flag for using local adversarial examples
    fine_tune_freq = args[
        "fine_tune_freq"]  # fine-tune the model every K images to save total model training time

    # store the attack input files (e.g., original image, target class)
    input_file_prefix = os.path.join(args["local_path"], target_model_name,
                                     args["attack_type"])
    os.system("mkdir -p {}".format(input_file_prefix))
    # save locally generated information
    local_info_file_prefix = os.path.join(args["local_path"],
                                          target_model_name,
                                          args["attack_type"],
                                          local_model_folder,
                                          str(args["seed"]))
    os.system("mkdir -p {}".format(local_info_file_prefix))
    # attack_input_file_prefix = os.path.join(args["local_path"],target_model_name,
    # 											args["attack_type"])
    # save bbox attack information
    out_dir_prefix = os.path.join(args["save_path"], args["attack_method"],
                                  target_model_name, args["attack_type"],
                                  local_model_folder, str(args["seed"]))
    os.system("mkdir -p {}".format(out_dir_prefix))

    #### generate the original images and target classes ####
    target_ys_one_hot,orig_images,target_ys,orig_labels,_, trans_test_images = \
    generate_attack_inputs(sess,target_model,x_test,y_test,class_num,nb_imgs,\
     load_imgs=args["load_imgs"],load_robust=load_robust,\
      file_path = input_file_prefix)
    #### end of genarating original images and target classes ####

    start_points = np.copy(
        orig_images)  # either start from orig seed or local advs
    # store attack statistical info
    dist_record = np.zeros(len(orig_labels), dtype=float)
    query_num_vec = np.zeros(len(orig_labels), dtype=int)
    success_vec = np.zeros(len(orig_labels), dtype=bool)
    adv_classes = np.zeros(len(orig_labels), dtype=int)

    # local model related variables
    if simple_target_model:
        local_model_file_name = "cifar10_simple"
    elif load_robust:
        local_model_file_name = "cifar10_robust"
    else:
        local_model_file_name = "cifar10"
    # save_dir = 'model/'+local_model_file_name + '/'
    callbacks_ls = []
    attacked_flag = np.zeros(len(orig_labels), dtype=bool)

    local_model_ls = []
    if with_local:
        ###################### start loading local models ###############################
        local_model_names_all = []  # help to store complete local model names
        sss = 0
        for model_name in local_model_names:
            if model_name == "adv_densenet" or model_name == "adv_vgg" or model_name == "adv_resnet":
                # tensoflow based robust local models
                loc_model = cifar10_tf_robust_models(sess, test_batch_size = test_batch_size, x = x,y = y, is_training=is_training,keep_prob=keep_prob,\
                 load_existing = True, model_name = model_name,loss = args["loss_function"])
                accuracy = loc_model.calcu_acc(x_test, y_test)
                local_model_ls.append(loc_model)
                print('Test accuracy of model {}: {:.4f}'.format(
                    model_name, accuracy))
                sss += 1
            else:
                # keras based local normal models
                if simple_local_model:
                    type_num = normal_local_types[sss]
                if model_name == 'resnet_v1' or model_name == 'resnet_v2':
                    depth_s = [20, 50, 110]
                else:
                    depth_s = [0]
                for depth in depth_s:
                    # model_name used for loading models
                    if model_name == 'resnet_v1' or model_name == 'resnet_v2':
                        model_load_name = model_name + str(depth)
                    else:
                        model_load_name = model_name
                    local_model_names_all.append(model_load_name)
                    if not simple_local_model:
                        loc_model = cifar10_models(sess,depth,test_batch_size = test_batch_size,use_softmax = True, x = x,y = y,\
                        load_existing = load_existing, model_name = model_name,loss = args["loss_function"])
                    else:
                        loc_model = cifar10_models_simple(sess,test_batch_size,type_num,use_softmax = True, x = x,y = y,\
                        is_training=is_training,keep_prob=keep_prob,load_existing = load_existing, model_name = model_name, loss = args["loss_function"])
                    local_model_ls.append(loc_model)

                    opt = keras.optimizers.SGD(lr=0.01,
                                               decay=1e-6,
                                               momentum=0.9,
                                               nesterov=True)
                    loc_model.model.compile(loss='categorical_crossentropy',
                                            optimizer=opt,
                                            metrics=['accuracy'])
                    orig_images_nw = orig_images
                    orig_labels_nw = orig_labels
                    if args["no_save_model"]:
                        if not load_existing:
                            loc_model.model.fit(
                                orig_images_nw,
                                orig_labels_nw,
                                batch_size=args["train_batch_size"],
                                epochs=sub_epochs,
                                verbose=0,
                                validation_data=(x_test, y_test),
                                shuffle=True)
                    else:
                        print(
                            "Saving local model is yet to be implemented, please check back later, system exiting!"
                        )
                        sys.exit(0)
                        # TODO: fix the issue of loading pretrained model first and then finetune the model
                        # if load_existing:
                        # 	filepath = save_dir + model_load_name + '_pretrained.h5'
                        # else:
                        # 	filepath = save_dir + model_load_name + '.h5'
                        # checkpoint = ModelCheckpoint(filepath=filepath,
                        # 							monitor='val_acc',
                        # 							verbose=0,
                        # 							save_best_only=True)
                        # callbacks = [checkpoint]
                        # callbacks_ls.append(callbacks)
                        # if not load_existing:
                        # 	print("Train on %d data and validate on %d data" % (len(orig_labels_nw),len(y_test)))
                        # 	loc_model.model.fit(orig_images_nw, orig_labels_nw,
                        # 		batch_size=args["train_batch_size"],
                        # 		epochs=sub_epochs,
                        # 		verbose=0,
                        # 		validation_data=(x_test, y_test),
                        # 		shuffle = True,
                        # 		callbacks = callbacks)
                    scores = loc_model.model.evaluate(x_test,
                                                      y_test,
                                                      verbose=0)
                    accuracy = scores[1]
                    print('Test accuracy of model {}: {:.4f}'.format(
                        model_load_name, accuracy))
                    sss += 1
        ##################### end of loading local models ######################################

        ##################### Define Attack Graphs of local PGD attack ###############################
        local_attack_graph = LinfPGDAttack(local_model_ls,
                                           epsilon=li_eps,
                                           k=k,
                                           a=a,
                                           random_start=True,
                                           loss_func=args["loss_function"],
                                           targeted=is_targeted,
                                           robust_indx=robust_indx,
                                           x=x,
                                           y=y,
                                           is_training=is_training,
                                           keep_prob=keep_prob)

        ##################### end of definining graphsof PGD attack ##########################

        ##################### generate local adversarial examples and also store the local attack information #####################
        if not args["load_local_AEs"]:
            # first do the transfer check to obtain local adversarial samples
            # generated local info can be used for batch attacks,
            # max_loss, min_loss, max_gap, min_gap etc are other metrics we explored for scheduling seeds based on local information
            if is_targeted:
                all_trans_rate, pred_labs, local_aes,pgd_cnt_mat, max_loss, min_loss, ave_loss, max_gap, min_gap, ave_gap\
                  = local_attack_in_batches(sess,start_points[np.logical_not(attacked_flag)],\
                target_ys_one_hot[np.logical_not(attacked_flag)],eval_batch_size = test_batch_size,\
                attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)
            else:
                all_trans_rate, pred_labs, local_aes,pgd_cnt_mat, max_loss, min_loss, ave_loss, max_gap, min_gap, ave_gap\
                  = local_attack_in_batches(sess,start_points[np.logical_not(attacked_flag)],\
                orig_labels[np.logical_not(attacked_flag)],eval_batch_size = test_batch_size,\
                attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)
            # calculate local adv loss used for scheduling seeds in batch attack...
            if is_targeted:
                adv_img_loss, free_idx = compute_cw_loss(sess,target_model,local_aes,\
                target_ys_one_hot,targeted=is_targeted,load_robust=load_robust)
            else:
                adv_img_loss, free_idx = compute_cw_loss(sess,target_model,local_aes,\
                orig_labels,targeted=is_targeted,load_robust=load_robust)

            # calculate orig img loss for scheduling seeds in baseline attack
            if is_targeted:
                orig_img_loss, free_idx = compute_cw_loss(sess,target_model,orig_images,\
                target_ys_one_hot,targeted=is_targeted,load_robust=load_robust)
            else:
                orig_img_loss, free_idx = compute_cw_loss(sess,target_model,orig_images,\
                orig_labels,targeted=is_targeted,load_robust=load_robust)

            pred_labs = np.argmax(target_model.predict_prob(local_aes), axis=1)
            if is_targeted:
                transfer_flag = np.argmax(target_ys_one_hot,
                                          axis=1) == pred_labs
            else:
                transfer_flag = np.argmax(orig_labels, axis=1) != pred_labs
            # save local aes
            np.save(local_info_file_prefix + '/local_aes.npy', local_aes)
            # store local info of local aes and original seeds: used for scheduling seeds in batch attacks
            np.savetxt(local_info_file_prefix + '/pgd_cnt_mat.txt',
                       pgd_cnt_mat)
            np.savetxt(local_info_file_prefix + '/orig_img_loss.txt',
                       orig_img_loss)
            np.savetxt(local_info_file_prefix + '/adv_img_loss.txt',
                       adv_img_loss)
            np.savetxt(local_info_file_prefix + '/ave_gap.txt', ave_gap)
        else:
            local_aes = np.load(local_info_file_prefix + '/local_aes.npy')
            if is_targeted:
                tmp_labels = target_ys_one_hot
            else:
                tmp_labels = orig_labels
            pred_labs = np.argmax(target_model.predict_prob(
                np.array(local_aes)),
                                  axis=1)
            print('correct number',
                  np.sum(pred_labs == np.argmax(tmp_labels, axis=1)))
            all_trans_rate = accuracy_score(np.argmax(tmp_labels, axis=1),
                                            pred_labs)
        ################################ end of generating local AEs and storing related information #######################################

        if not is_targeted:
            all_trans_rate = 1 - all_trans_rate
        print('** Transfer Rate: **' + str(all_trans_rate))

        if all_trans_rate > use_loc_adv_thres:
            print("Updated the starting points to local AEs....")
            start_points[np.logical_not(attacked_flag)] = local_aes
            use_loc_adv_flag = True

        # independent test set for checking transferability: for experiment purpose and does not count for query numbers
        if is_targeted:
            ind_all_trans_rate,_,_,_,_,_,_,_,_,_ = local_attack_in_batches(sess,trans_test_images,target_ys_one_hot,eval_batch_size = test_batch_size,\
            attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)
        else:
            ind_all_trans_rate,_,_,_,_,_,_,_,_,_ = local_attack_in_batches(sess,trans_test_images,orig_labels,eval_batch_size = test_batch_size,\
            attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)

        # record the queries spent by quering the local samples
        query_num_vec[np.logical_not(attacked_flag)] += 1
        if not is_targeted:
            ind_all_trans_rate = 1 - ind_all_trans_rate
        print('** (Independent Set) Transfer Rate: **' +
              str(ind_all_trans_rate))
        all_trans_rate_ls.append(ind_all_trans_rate)

    S = np.copy(start_points)
    S_label = target_model.predict_prob(S)
    S_label_cate = np.argmax(S_label, axis=1)
    S_label_cate = np_utils.to_categorical(S_label_cate, class_num)

    pre_free_idx = []
    candi_idx_ls = []  # store the indices of images in the order attacked

    # these parameters are used to make sure equal number of instances from each class are selected
    # such that diversity of fine-tuning set is improved. However, it is not effective...
    per_cls_cnt = 0
    cls_order = 0
    change_limit = False
    max_lim_num = int(fine_tune_freq / class_num)

    # define the autozoom bbox attack graph
    if args["attack_method"] == "autozoom":
        # setup the autoencoders for autozoom attack
        codec = 0
        args["img_resize"] = 8
        # replace with your directory
        codec_dir = 'CIFAR10_models/cifar10_autoencoder/'  # TODO: replace with your own cifar10 autoencoder directory
        encoder = load_model(codec_dir + 'whole_cifar10_encoder.h5')
        decoder = load_model(codec_dir + 'whole_cifar10_decoder.h5')

        encode_img = encoder.predict(data.test_data[100:101])
        decode_img = decoder.predict(encode_img)
        diff_img = (decode_img - data.test_data[100:101])
        diff_mse = np.mean(diff_img.reshape(-1)**2)

        # diff_mse = np.mean(np.sum(diff_img.reshape(-1,784)**2,axis = 1))
        print("[Info][AE] MSE:{:.4f}".format(diff_mse))
        encode_img = encoder.predict(data.test_data[0:1])
        decode_img = decoder.predict(encode_img)
        diff_img = (decode_img - data.test_data[0:1])
        diff_mse = np.mean(diff_img.reshape(-1)**2)
        print("[Info][AE] MSE:{:.4f}".format(diff_mse))

    if args["attack_method"] == "autozoom":
        # define black-box model graph of autozoom
        autozoom_graph = AutoZOOM(sess, target_model, args, decoder, codec,
                                  num_channels, image_size, class_num)

    # main loop of hybrid attacks
    for itr in range(len(orig_labels)):
        print("#------------ Substitue training round {} ----------------#".
              format(itr))
        # computer loss functions of seeds: no query is needed here because seeds are already queried before...
        if is_targeted:
            img_loss, free_idx = compute_cw_loss(sess,target_model,start_points,\
            target_ys_one_hot,targeted=is_targeted,load_robust=load_robust)
        else:
            img_loss, free_idx = compute_cw_loss(sess,target_model,start_points,\
            orig_labels,targeted=is_targeted,load_robust=load_robust)
        free_idx_diff = list(set(free_idx) - set(pre_free_idx))
        print("new free idx found:", free_idx_diff)
        if len(free_idx_diff) > 0:
            candi_idx_ls.extend(free_idx_diff)
        pre_free_idx = free_idx
        if with_local:
            if len(free_idx) > 0:
                # free attacks are found
                attacked_flag[free_idx] = 1
                success_vec[free_idx] = 1
                # update dist and adv class
                if args['dist_metric'] == 'l2':
                    dist = np.sum(
                        (start_points[free_idx] - orig_images[free_idx])**2,
                        axis=(1, 2, 3))**.5
                elif args['dist_metric'] == 'li':
                    dist = np.amax(np.abs(start_points[free_idx] -
                                          orig_images[free_idx]),
                                   axis=(1, 2, 3))
                # print(start_points[free_idx].shape)
                adv_class = target_model.pred_class(start_points[free_idx])
                adv_classes[free_idx] = adv_class
                dist_record[free_idx] = dist
                if np.amax(
                        dist
                ) >= args["cost_threshold"] + args["cost_threshold"] / 10:
                    print(
                        "there are some problems in setting the perturbation distance!"
                    )
                    sys.exit(0)
        print("Number of Unattacked Seeds: ",
              np.sum(np.logical_not(attacked_flag)))
        if attacked_flag.all():
            # early stop when all seeds are sucessfully attacked
            break

        # define the seed generation process as a functon
        if args["sort_metric"] == "min":
            img_loss[attacked_flag] = 1e10
        elif args["sort_metric"] == "max":
            img_loss[attacked_flag] = -1e10
        candi_idx, per_cls_cnt, cls_order,change_limit,max_lim_num = select_next_seed(img_loss,attacked_flag,args["sort_metric"],\
        args["by_class"],fine_tune_freq,class_num,per_cls_cnt,cls_order,change_limit,max_lim_num)

        print(candi_idx)
        candi_idx_ls.append(candi_idx)

        input_img = start_points[candi_idx:candi_idx + 1]
        if args["attack_method"] == "autozoom":
            # encoder decoder performance check
            encode_img = encoder.predict(input_img)
            decode_img = decoder.predict(encode_img)
            diff_img = (decode_img - input_img)
            diff_mse = np.mean(diff_img.reshape(-1)**2)
        else:
            diff_mse = 0.0

        print("[Info][Start]: test_index:{}, true label:{}, target label:{}, MSE:{}".format(candi_idx, np.argmax(orig_labels[candi_idx]),\
         np.argmax(target_ys_one_hot[candi_idx]),diff_mse))

        ################## BEGIN: bbox attacks ############################
        if args["attack_method"] == "autozoom":
            # perform bbox attacks
            if is_targeted:
                x_s, ae, query_num = autozoom_attack(
                    autozoom_graph, input_img,
                    orig_images[candi_idx:candi_idx + 1],
                    target_ys_one_hot[candi_idx])
            else:
                x_s, ae, query_num = autozoom_attack(
                    autozoom_graph, input_img,
                    orig_images[candi_idx:candi_idx + 1],
                    orig_labels[candi_idx])
        else:
            if is_targeted:
                x_s, query_num, ae = nes_attack(args,target_model,input_img,orig_images[candi_idx:candi_idx+1],\
                 np.argmax(target_ys_one_hot[candi_idx]), lower = clip_min, upper = clip_max)
            else:
                x_s, query_num, ae = nes_attack(args,target_model,input_img,orig_images[candi_idx:candi_idx+1],\
                 np.argmax(orig_labels[candi_idx]), lower = clip_min, upper = clip_max)
            x_s = np.squeeze(np.array(x_s), axis=1)
        ################## END: bbox attacks ############################

        attacked_flag[candi_idx] = 1

        # fill the query info, etc
        if len(ae.shape) == 3:
            ae = np.expand_dims(ae, axis=0)
        if args['dist_metric'] == 'l2':
            dist = np.sum((ae - orig_images[candi_idx])**2)**.5
        elif args['dist_metric'] == 'li':
            dist = np.amax(np.abs(ae - orig_images[candi_idx]))
        adv_class = target_model.pred_class(ae)
        adv_classes[candi_idx] = adv_class
        dist_record[candi_idx] = dist

        if args["attack_method"] == "autozoom":
            # autozoom utilizes the query info of attack input, which is already done at the begining.
            added_query = query_num - 1
        else:
            added_query = query_num

        query_num_vec[candi_idx] += added_query
        if dist >= args["cost_threshold"] + args["cost_threshold"] / 10:
            print("the distance is not optimized properly")
            sys.exit(0)

        if is_targeted:
            if adv_class == np.argmax(target_ys_one_hot[candi_idx]):
                success_vec[candi_idx] = 1
        else:
            if adv_class != np.argmax(orig_labels[candi_idx]):
                success_vec[candi_idx] = 1
        if attacked_flag.all():
            print(
                "Early termination because all seeds are successfully attacked!"
            )
            break
        ##############################################################
        ## Starts the section of substitute training and local advs ##
        ##############################################################
        if with_local:
            if not stop_fine_tune_flag:
                # augment the local model training data with target model labels
                print(np.array(x_s).shape)
                print(S.shape)
                S = np.concatenate((S, np.array(x_s)), axis=0)
                S_label_add = target_model.predict_prob(np.array(x_s))
                S_label_add_cate = np.argmax(S_label_add, axis=1)
                S_label_add_cate = np_utils.to_categorical(
                    S_label_add_cate, class_num)
                S_label_cate = np.concatenate((S_label_cate, S_label_add_cate),
                                              axis=0)
                # empirically, tuning with model prediction probabilities given slightly better results.
                # if your bbox attack is decision based, only use the prediction labels
                S_label = np.concatenate((S_label, S_label_add), axis=0)
                # fine-tune the model
                if itr % fine_tune_freq == 0 and itr != 0:
                    if len(S_label) > args["train_inst_lim"]:
                        curr_len = len(S_label)
                        rand_idx = np.random.choice(len(S_label),
                                                    args["train_inst_lim"],
                                                    replace=False)
                        S = S[rand_idx]
                        S_label = S_label[rand_idx]
                        S_label_cate = S_label_cate[rand_idx]
                        print(
                            "current num: %d, max train instance limit %d is reached, performed random sampling to get %d samples!"
                            % (curr_len, len(S_label), len(rand_idx)))
                    sss = 0

                    for loc_model in local_model_ls:
                        model_name = local_model_names_all[sss]
                        if args["use_mixup"]:
                            print(
                                "Updates the training data with mixup strayegy!"
                            )
                            S_nw = np.copy(S)
                            S_label_nw = np.copy(S_label)
                            S_nw, S_label_nw, _ = mixup_data(S_nw,
                                                             S_label_nw,
                                                             alpha=alpha)
                        else:
                            S_nw = S
                            S_label_nw = S_label
                        print("Train on %d data and validate on %d data" %
                              (len(S_label_nw), len(y_test)))
                        if args["no_save_model"]:
                            loc_model.model.fit(
                                S_nw,
                                S_label_nw,
                                batch_size=args["train_batch_size"],
                                epochs=sub_epochs,
                                verbose=0,
                                validation_data=(x_test, y_test),
                                shuffle=True)
                        else:
                            print(
                                "Saving local model is yet to be implemented, please check back later, system exiting!"
                            )
                            sys.exit(0)
                            # callbacks = callbacks_ls[sss]
                            # loc_model.model.fit(S_nw, S_label_nw,
                            # 	batch_size=args["train_batch_size"],
                            # 	epochs=sub_epochs,
                            # 	verbose=0,
                            # 	validation_data=(x_test, y_test),
                            # 	shuffle = True,
                            # 	callbacks = callbacks)
                        scores = loc_model.model.evaluate(x_test,
                                                          y_test,
                                                          verbose=0)
                        print('Test accuracy of model {}: {:.4f}'.format(
                            model_name, scores[1]))
                        sss += 1
                    if not attacked_flag.all():
                        # first check for not attacked seeds
                        if is_targeted:
                            remain_trans_rate, _, remain_local_aes,_, _, _, _, _, _, _\
                              = local_attack_in_batches(sess,orig_images[np.logical_not(attacked_flag)],\
                            target_ys_one_hot[np.logical_not(attacked_flag)],eval_batch_size = test_batch_size,\
                            attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)
                        else:
                            remain_trans_rate, pred_labs, remain_local_aes,_, _, _, _, _, _, _\
                              = local_attack_in_batches(sess,orig_images[np.logical_not(attacked_flag)],\
                            orig_labels[np.logical_not(attacked_flag)],eval_batch_size = test_batch_size,\
                            attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)
                        if not is_targeted:
                            remain_trans_rate = 1 - remain_trans_rate
                        print('<<Ramaining Seed Transfer Rate>>:**' +
                              str(remain_trans_rate))
                        # if transfer rate is higher than threshold, use local advs as starting points
                        if remain_trans_rate <= 0 and use_loc_adv_flag:
                            print(
                                "No improvement for substitue training, stop fine-tuning!"
                            )
                            stop_fine_tune_flag = False

                        # transfer rate check with independent test examples
                        if is_targeted:
                            all_trans_rate, _, _, _, _, _, _, _, _, _\
                              = local_attack_in_batches(sess,trans_test_images,target_ys_one_hot,eval_batch_size = test_batch_size,\
                            attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)
                        else:
                            all_trans_rate, _, _, _, _, _, _, _, _, _\
                              = local_attack_in_batches(sess,trans_test_images,orig_labels,eval_batch_size = test_batch_size,\
                            attack_graph = local_attack_graph,model = target_model,clip_min=clip_min,clip_max=clip_max,load_robust=load_robust)
                        if not is_targeted:
                            all_trans_rate = 1 - all_trans_rate
                        print('<<Overall Transfer Rate>>: **' +
                              str(all_trans_rate))

                        # if trans rate is not high enough, still start from orig seed; start from loc adv only
                        # when trans rate is high enough, useful when you start with random model
                        if not use_loc_adv_flag:
                            if remain_trans_rate > use_loc_adv_thres:
                                use_loc_adv_flag = True
                                print("Updated the starting points....")
                                start_points[np.logical_not(
                                    attacked_flag)] = remain_local_aes
                            # record the queries spent on checking newly generated loc advs
                            query_num_vec += 1
                        else:
                            print("Updated the starting points....")
                            start_points[np.logical_not(
                                attacked_flag)] = remain_local_aes
                            # record the queries spent on checking newly generated loc advs
                            query_num_vec[np.logical_not(attacked_flag)] += 1
                        remain_trans_rate_ls.append(remain_trans_rate)
                        all_trans_rate_ls.append(all_trans_rate)
                np.set_printoptions(precision=4)
                print("all_trans_rate:")
                print(all_trans_rate_ls)
                print("remain_trans_rate")
                print(remain_trans_rate_ls)

    # save the query information of all classes
    if not args["no_save_text"]:
        save_name_file = os.path.join(out_dir_prefix,
                                      "{}_num_queries.txt".format(loc_adv))
        np.savetxt(save_name_file, query_num_vec, fmt='%d', delimiter=' ')
        save_name_file = os.path.join(out_dir_prefix,
                                      "{}_success_flags.txt".format(loc_adv))
        np.savetxt(save_name_file, success_vec, fmt='%d', delimiter=' ')
# Example #3
# 0
def main():
    """Run a local-model-seeded NES attack against a DenseNet121 ImageNet model.

    Steps performed:

    1. Create a TF session and register it as the Keras backend session.
    2. Load a batch of ``TOT_IMAGES`` images from ``INPUT_DIR`` via
       ``ImageNetProducer``. Report the accuracy of the wrapped DenseNet121
       target model on it.
    3. If ``Results/<args.seed>`` does not exist, do the following:
       - Run an L-inf PGD attack (``LinfPGDAttack``) on the local models
         VGG16, VGG19 and ResNet50 to produce adversarial seeds.
       - Print transfer success rates.
       - Save losses, statistics and seeds into that directory.
       If the directory does exist, reload the saved seeds instead.
    4. For images ``args.bstart:args.bend``, call ``nes_attack``. It starts
       from the local adversarial example when ``args.attack_seed_type ==
       'adv'``, and from the original image otherwise. Per-image query counts
       and success flags are saved next to the cached seeds.

    This function reads the module-level ``args`` object and the module-level
    constants ``TOT_IMAGES``, ``INPUT_DIR`` and ``IMAGE_SIZE``. None of them is
    defined here.
    NOTE(review): ``args`` is presumably this script's parsed argparse
    namespace — confirm.

    Raises:
        RuntimeError: if Keras is not configured with the TensorFlow backend,
            or if the image producer yields no batch.
    """
    tf.set_random_seed(1234)  # for producing the same images

    if not hasattr(keras.backend, "tf"):
        raise RuntimeError("This tutorial requires keras to be configured"
                           " to use the TensorFlow backend.")

    if keras.backend.image_dim_ordering() != 'tf':
        keras.backend.set_image_dim_ordering('tf')
        print("INFO: '~/.keras/keras.json' sets 'image_dim_ordering' to "
              "'th', temporarily setting to 'tf'")

    sess = tf.Session()
    keras.backend.set_session(sess)

    # load and preprocess dataset
    data_spec = DataSpec(batch_size=TOT_IMAGES,
                         scale_size=256,
                         crop_size=224,
                         isotropic=False)
    image_producer = ImageNetProducer(data_path=INPUT_DIR,
                                      num_images=TOT_IMAGES,
                                      data_spec=data_spec,
                                      batch_size=TOT_IMAGES)

    # Define input TF placeholder
    x = tf.placeholder(tf.float32, shape=(None, 224, 224, 3))
    y = tf.placeholder(tf.float32, shape=(None, 1000))
    class_num = 1000

    # load target model and produce data
    # model = preprocess layer + pretrained model
    from keras.applications.densenet import DenseNet121
    from keras.applications.densenet import preprocess_input
    pretrained_model = DenseNet121(weights='imagenet')
    image_producer.startover()
    target_model = keras_model_wrapper(pretrained_model,
                                       preprocess_input,
                                       x=x,
                                       y=y)

    # The producer's batch size equals TOT_IMAGES, so a single batch is
    # expected. If several batches were yielded, only the last one is kept.
    images = None
    label = None
    for (_, batch_labels, _, batch_images) in image_producer.batches(sess):
        images = np.array(batch_images)
        label = np_utils.to_categorical(np.array(batch_labels), class_num)
    if images is None:
        # Fail here with a clear message. Otherwise the missing data would
        # only surface later as an UnboundLocalError in model_eval.
        raise RuntimeError(
            'ImageNetProducer yielded no images from {}'.format(INPUT_DIR))

    accuracy = model_eval(sess,
                          x,
                          y,
                          target_model.predictions,
                          images,
                          label,
                          args={'batch_size': 32})
    print('Test accuracy of wrapped target model:{:.4f}'.format(accuracy))

    # data information
    x_test, y_test = images, label  # x_test [0, 255]
    print('loading {} images in total, shape {}'.format(len(images),
                                                        images.shape))
    print(np.min(x_test), np.max(x_test))

    # local attack specific parameters
    clip_min = args.lower
    clip_max = args.upper
    nb_imgs = args.nb_imgs
    li_eps = args.epsilon
    targeted_true = args.attack_type == 'targeted'
    k = args.K  # iteration
    a = args.learning_rate  # step size

    # Test the accuracy of targeted attacks, need to redefine the attack graph
    target_ys_one_hot, orig_images, target_ys, orig_labels = generate_attack_inputs(
        target_model, x_test, y_test, class_num, nb_imgs)

    # Set random seed to improve reproducibility
    tf.set_random_seed(args.seed)
    np.random.seed(args.seed)

    # Local adversarial examples are cached per seed under Results/<seed>:
    # generate them if the directory is missing, otherwise load them from it.
    prefix = os.path.join("Results", str(args.seed))

    if not os.path.exists(prefix):  # no history info
        # load local models or define the architecture
        local_model_types = ['VGG16', 'VGG19', 'resnet50']
        local_model_ls = []
        for model_type in local_model_types:
            pretrained_model, preprocess_input_func = load_model(model_type)
            local_model = keras_model_wrapper(pretrained_model,
                                              preprocess_input_func,
                                              x=x,
                                              y=y)
            accuracy = model_eval(sess,
                                  x,
                                  y,
                                  local_model.predictions,
                                  images,
                                  label,
                                  args={'batch_size': 32})
            print('Test accuracy of model {}: {:.4f}'.format(
                model_type, accuracy))
            local_model_ls.append(local_model)

        # The CW loss and the PGD attack both use the target classes in the
        # targeted setting and the true labels in the untargeted one.
        attack_labels = target_ys_one_hot if targeted_true else orig_labels

        # load local model attack graph
        orig_img_loss = compute_cw_loss(target_model,
                                        orig_images,
                                        attack_labels,
                                        targeted=targeted_true)

        local_attack_graph = LinfPGDAttack(local_model_ls,
                                           epsilon=li_eps,
                                           k=k,
                                           a=a,
                                           random_start=False,
                                           loss_func='xent',
                                           targeted=targeted_true,
                                           x=x,
                                           y=y)
        # pgd attack to local models and generate adversarial example seed
        (_, pred_labs, local_aes, pgd_cnt_mat, max_loss, min_loss, ave_loss,
         max_gap, min_gap, ave_gap) = local_attack_in_batches(
             sess,
             orig_images,
             attack_labels,
             eval_batch_size=1,
             attack_graph=local_attack_graph,
             model=target_model,
             clip_min=clip_min,
             clip_max=clip_max)

        # calculate the loss for all adversarial seeds
        adv_img_loss = compute_cw_loss(target_model,
                                       local_aes,
                                       attack_labels,
                                       targeted=targeted_true)

        success_rate = accuracy_score(target_ys, pred_labs)
        print(
            '** Success rate of targeted adversarial examples generated from local models: **'
            + str(success_rate))
        accuracy = accuracy_score(np.argmax(orig_labels, axis=1), pred_labs)
        print(
            '** Success rate of targeted adversarial examples generated by local models (untargeted): **'
            + str(1 - accuracy))

        # save the generated local adversarial example ...
        os.makedirs(prefix)
        # save statistics (one text file per array, same order as before)
        statistics = (
            ('adv_img_loss', adv_img_loss),
            ('orig_img_loss', orig_img_loss),
            ('pgd_cnt_mat', pgd_cnt_mat),
            ('max_loss', max_loss),
            ('min_loss', min_loss),
            ('ave_loss', ave_loss),
            ('max_gap', max_gap),
            ('min_gap', min_gap),
            ('ave_gap', ave_gap),
        )
        for stat_name, stat_values in statistics:
            np.savetxt(os.path.join(prefix, stat_name + '.txt'), stat_values)

        # save output for local attacks; these files are what the `else`
        # branch below reloads on subsequent runs with the same seed
        local_outputs = (
            ('local_aes', local_aes),
            ('orig_images', orig_images),
            ('target_ys', target_ys),
            ('target_ys_one_hot', target_ys_one_hot),
        )
        for array_name, array_values in local_outputs:
            np.save(os.path.join(prefix, array_name + '.npy'), array_values)
    else:
        print('loading data from files')
        local_aes = np.load(os.path.join(prefix, 'local_aes.npy'))
        orig_images = np.load(os.path.join(prefix, 'orig_images.npy'))
        target_ys = np.load(os.path.join(prefix, 'target_ys.npy'))
        target_ys_one_hot = np.load(
            os.path.join(prefix, 'target_ys_one_hot.npy'))

    # sanity checks on the generated or reloaded arrays
    assert local_aes.shape == (nb_imgs, 224, 224, 3)
    assert orig_images.shape == (nb_imgs, 224, 224, 3)
    assert target_ys.shape == (nb_imgs, )
    assert target_ys_one_hot.shape == (nb_imgs, class_num)

    print('begin NES attack')
    num_queries_list = []
    success_flags = []
    # fetch batch
    batch = slice(args.bstart, args.bend)
    orig_images = orig_images[batch]
    target_ys = target_ys[batch]
    local_aes = local_aes[batch]
    use_local_adv_seed = args.attack_seed_type == 'adv'
    # begin loop
    for idx, target_class in enumerate(target_ys):
        # keep a leading batch axis of size 1 for the original image
        initial_img = orig_images[idx:idx + 1]
        print('attack seed is %s' % args.attack_seed_type)
        attack_seed = local_aes[idx] if use_local_adv_seed else orig_images[idx]
        _, num_queries, _ = nes_attack(sess, args, target_model, attack_seed,
                                       initial_img, target_class, class_num,
                                       IMAGE_SIZE)
        # an attack that used up the whole query budget counts as a failure
        success_flags.append(0 if num_queries == args.max_queries else 1)
        num_queries_list.append(num_queries)

    # save query number and success
    fname = os.path.join(prefix,
                         '{}_num_queries.txt'.format(args.attack_seed_type))
    np.savetxt(fname, num_queries_list)
    fname = os.path.join(prefix,
                         '{}_success_flags.txt'.format(args.attack_seed_type))
    np.savetxt(fname, success_flags)

    print('finish NES attack')