def test_lbfgs_attack():
    """
    LBFGS-Attack test.

    Loads a trained LeNet5 checkpoint, takes the first sample of the test
    images, draws a random target class that differs from the true label and
    runs a targeted LBFGS attack on it. No assertion is made on the result;
    the test passes as long as ``generate`` does not raise.
    """
    np.random.seed(123)
    # upload trained network
    current_dir = os.path.dirname(os.path.abspath(__file__))
    ckpt_name = os.path.join(
        current_dir,
        '../test_data/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt')
    net = LeNet5()
    load_dict = load_checkpoint(ckpt_name)
    load_param_into_net(net, load_dict)

    # get one mnist image
    input_np = np.load(
        os.path.join(current_dir, '../test_data/test_images.npy'))[:1]
    label_np = np.load(
        os.path.join(current_dir, '../test_data/test_labels.npy'))[:1]
    LOGGER.debug(TAG, 'true label is :{}'.format(label_np[0]))

    classes = 10
    # Always draw with size=1 so target_np stays a length-1 array. A scalar
    # re-draw would make the one-hot target below 1-D (no batch axis), which
    # no longer lines up with the single-sample input batch.
    target_np = np.random.randint(0, classes, 1)
    while target_np[0] == label_np[0]:
        target_np = np.random.randint(0, classes, 1)
    # One-hot encode the target class; the result has shape (1, classes).
    target_np = np.eye(classes)[target_np].astype(np.float32)

    attack = LBFGS(net, is_targeted=True)
    LOGGER.debug(TAG, 'target_np is :{}'.format(target_np[0]))
    # The returned adversarial data is not inspected by this test.
    attack.generate(input_np, target_np)
def test_pointwise_attack_method():
    """
    Pointwise attack method unit test.

    Loads a trained LeNet5 checkpoint, runs an untargeted sparse
    PointWiseAttack on the first three test images (using the model's own
    predictions as labels) and checks that the first sample flagged as
    adversarial differs from its original input.
    """
    np.random.seed(123)
    # upload trained network
    current_dir = os.path.dirname(os.path.abspath(__file__))
    ckpt_name = os.path.join(
        current_dir,
        '../../test_data/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt')
    net = LeNet5()
    load_dict = load_checkpoint(ckpt_name)
    load_param_into_net(net, load_dict)

    # get the first three mnist images and their labels
    input_np = np.load(
        os.path.join(current_dir, '../../test_data/test_images.npy'))[:3]
    labels = np.load(
        os.path.join(current_dir, '../../test_data/test_labels.npy'))[:3]
    model = ModelToBeAttacked(net)
    pre_label = np.argmax(model.predict(input_np), axis=1)
    LOGGER.info(TAG, 'original sample predict labels are :{}'.format(pre_label))
    LOGGER.info(TAG, 'true labels are: {}'.format(labels))

    attack = PointWiseAttack(model, sparse=True, is_targeted=False)
    # The third returned value (query times) is not used by this test.
    is_adv, adv_data, _ = attack.generate(input_np, pre_label)
    LOGGER.info(
        TAG, 'adv sample predict labels are: {}'.format(
            np.argmax(model.predict(adv_data), axis=1)))

    # Check that at least one sample was flagged first; otherwise the [0]
    # index below would raise a bare IndexError on an empty selection and
    # hide the real reason for the failure.
    assert np.any(is_adv), 'Pointwise attack method: ' \
                           'no adversarial sample was generated.'
    assert np.any(adv_data[is_adv][0] != input_np[is_adv][0]), \
        'Pointwise attack method: ' \
        'generate value must not be equal' \
        ' to original value.'
def get_model(current_dir):
    """
    Build the attack target used by the tests.

    Args:
        current_dir: Directory against which the relative checkpoint path
            is resolved.

    Returns:
        A ModelToBeAttacked wrapping a LeNet5 whose parameters were loaded
        from the checkpoint and on which ``set_train(False)`` was called.
    """
    checkpoint_path = os.path.join(
        current_dir,
        '../../test_data/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt')
    lenet = LeNet5()
    # Read the checkpoint and copy its parameters into the fresh network.
    load_param_into_net(lenet, load_checkpoint(checkpoint_path))
    # Turn training mode off before handing the network to the wrapper.
    lenet.set_train(False)
    return ModelToBeAttacked(lenet)