def __init__(self, dataset, batch_size, targeted, target_type, epsilon, norm, lower_bound=0.0, upper_bound=1.0,
                 max_queries=10000):
        """
            Configure a batched black-box attack and allocate per-image statistics.

            :param dataset: dataset name used to build the attacked test loader
            :param batch_size: batch size of the attacked test data loader
            :param targeted: whether the attack is targeted
            :param target_type: strategy used to pick target labels for targeted attacks
            :param epsilon: perturbation limit according to lp-ball
            :param norm: norm for the lp-ball constraint, either 'linf' or 'l2'
            :param lower_bound: minimum value data point can take in any coordinate
            :param upper_bound: maximum value data point can take in any coordinate
            :param max_queries: max number of calls to model per data point
        """
        assert norm in ['linf', 'l2'], "{} is not supported".format(norm)
        self.epsilon = epsilon
        self.norm = norm
        self.max_queries = max_queries

        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        # projection onto the lp-ball; expected to be assigned later by the attack loop
        self._proj = None
        self.is_new_batch = False
        # self.early_stop_crit_fct = lambda model, x, y: 1 - model(x).max(1)[1].eq(y)
        self.targeted = targeted
        self.target_type = target_type

        self.dataset_loader = DataLoaderMaker.get_test_attacked_data(dataset, batch_size)
        self.total_images = len(self.dataset_loader.dataset)

        # Per-image bookkeeping tensors, one float entry per test image.
        self.query_all = torch.zeros(self.total_images)
        self.correct_all = torch.zeros_like(self.query_all)  # number of images
        self.not_done_all = torch.zeros_like(self.query_all)  # always set to 0 if the original image is misclassified
        self.success_all = torch.zeros_like(self.query_all)
        self.success_query_all = torch.zeros_like(self.query_all)
        self.not_done_prob_all = torch.zeros_like(self.query_all)

        self.lowest_change_ratio = 0.05 # the percentage of pixels that changes
# ---- Exemplo n.º 2 (Example 2) ----
 def __init__(self, dataset, batch_size, targeted, target_type, epsilon,
              norm, lower_bound=0.0, upper_bound=1.0, max_queries=10000):
     """Store the attack hyper-parameters and allocate per-image statistics."""
     assert norm in ['linf', 'l2'], "{} is not supported".format(norm)
     # lp-ball constraint and query budget
     self.epsilon = epsilon
     self.norm = norm
     self.max_queries = max_queries
     # valid pixel-value range of the data
     self.lower_bound = lower_bound
     self.upper_bound = upper_bound
     # targeted-attack configuration
     self.targeted = targeted
     self.target_type = target_type
     # attacked test set
     self.dataset_loader = DataLoaderMaker.get_test_attacked_data(dataset, batch_size)
     self.total_images = len(self.dataset_loader.dataset)
     # per-image bookkeeping, one float per test image
     n = self.total_images
     self.query_all = torch.zeros(n)
     self.correct_all = torch.zeros(n)   # number of images classified correctly
     self.not_done_all = torch.zeros(n)  # stays 0 when the clean image is already misclassified
     self.success_all = torch.zeros(n)
     self.success_query_all = torch.zeros(n)
     self.not_done_prob_all = torch.zeros(n)
 def __init__(self, dataset, batch_size, targeted, target_type, epsilon, norm,
              lower_bound=0.0, upper_bound=1.0, max_queries=10000,
              surrogate_model_names=None):
     """Initialize attack state plus per-image loss/surrogate diagnostic records."""
     assert norm in ['linf', 'l2'], "{} is not supported".format(norm)
     self.epsilon = epsilon
     self.norm = norm
     self.max_queries = max_queries
     self.lower_bound = lower_bound
     self.upper_bound = upper_bound
     self.targeted = targeted
     self.target_type = target_type
     self.dataset_loader = DataLoaderMaker.get_test_attacked_data(dataset, batch_size)
     self.total_images = len(self.dataset_loader.dataset)
     n = self.total_images
     self.query_all = torch.zeros(n)
     self.correct_all = torch.zeros(n)   # number of images classified correctly
     self.not_done_all = torch.zeros(n)  # stays 0 when the original image is misclassified
     self.success_all = torch.zeros(n)
     self.success_query_all = torch.zeros(n)
     self.not_done_prob_all = torch.zeros(n)
     # self.cos_similarity_all = torch.zeros(self.total_images, max_queries)   # N, T
     # Diagnostic records, each an empty ordered mapping filled by the attack loop.
     for record_attr in (
             "increase_loss_from_last_iter_with_1st_model_grad_record_all",
             "increase_loss_from_last_iter_with_2nd_model_grad_record_all",
             "increase_loss_from_last_iter_after_switch_record_all",
             "surrogate_model_record_all",
             "loss_x_pos_temp_record_all",
             "loss_x_neg_temp_record_all",
             "loss_after_switch_grad_record_all",
     ):
         setattr(self, record_attr, OrderedDict())
     self.surrogate_model_names = surrogate_model_names
# ---- Exemplo n.º 4 (Example 4) ----
    def __init__(self, args, directions_generator):
        """
            Configure the attack from parsed command-line arguments.

            :param args: argparse namespace; must provide the fields read
                         below (no_rank_transform, random_mask, image_split,
                         sub_num_sample, sigma, starting_eps, epsilon,
                         sample_per_draw, max_queries, delta_eps, max_lr,
                         min_lr, targeted, norm, dataset)
            :param directions_generator: object that proposes perturbation
                         directions for the attack (used later, only stored here)
        """
        self.rank_transform = not args.no_rank_transform
        self.random_mask = args.random_mask

        self.image_split = args.image_split
        self.sub_num_sample = args.sub_num_sample
        self.sigma = args.sigma
        self.starting_eps = args.starting_eps
        self.eps = args.epsilon
        self.sample_per_draw = args.sample_per_draw
        self.directions_generator = directions_generator
        self.max_iter = args.max_queries
        self.delta_eps = args.delta_eps
        self.max_lr = args.max_lr
        self.min_lr = args.min_lr
        self.targeted = args.targeted
        self.norm = args.norm

        # batch size of 1: images are attacked one at a time
        self.dataset_loader = DataLoaderMaker.get_test_attacked_data(
            args.dataset, 1)
        self.total_images = len(self.dataset_loader.dataset)
        self.query_all = torch.zeros(self.total_images)
        self.correct_all = torch.zeros_like(self.query_all)  # number of images
        self.not_done_all = torch.zeros_like(
            self.query_all
        )  # always set to 0 if the original image is misclassified
        self.success_all = torch.zeros_like(self.query_all)
        self.success_query_all = torch.zeros_like(self.query_all)
        self.not_done_prob_all = torch.zeros_like(self.query_all)
        self.dataset_name = args.dataset
# ---- Exemplo n.º 5 (Example 5) ----
    def __init__(self, args, dataset, targeted, target_type, epsilon, norm, lower_bound=0.0, upper_bound=1.0,
                 max_queries=10000):
        """
            :param args: argparse namespace providing batch_size, max_queries,
                         norm, clip_min, clip_max, lr and beta1
            :param dataset: dataset name used to build the attacked test loader
            :param targeted: whether the attack is targeted
            :param target_type: strategy used to pick target labels
            :param epsilon: perturbation limit according to lp-ball
            :param norm: norm for the lp-ball constraint
            :param lower_bound: minimum value data point can take in any coordinate
            :param upper_bound: maximum value data point can take in any coordinate
            :param max_queries: max number of calls to model per data point
        """
        assert norm in ['linf', 'l2'], "{} is not supported".format(norm)
        self.epsilon = epsilon
        self.norm = norm
        self.max_queries = max_queries

        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        # self.early_stop_crit_fct = lambda model, x, y: 1 - model(x).max(1)[1].eq(y)
        self.targeted = targeted
        self.target_type = target_type

        self.data_loader = DataLoaderMaker.get_test_attacked_data(dataset, args.batch_size)
        self.total_images = len(self.data_loader.dataset)
        # NOTE(review): att_iter duplicates args.max_queries, which may differ
        # from the max_queries parameter stored above — confirm which one the
        # attack loop actually honors.
        self.att_iter = args.max_queries
        self.correct_all = torch.zeros(self.total_images)  # number of images
        self.not_done_all = torch.zeros(self.total_images)  # always set to 0 if the original image is misclassified
        self.success_all = torch.zeros(self.total_images)
        self.not_done_prob_all = torch.zeros(self.total_images)
        self.stop_iter_all = torch.zeros(self.total_images)
        # NOTE(review): reads args.norm, not the `norm` parameter — confirm the
        # two always agree at call sites.
        self.ord =  args.norm # linf, l1, l2
        self.clip_min = args.clip_min
        self.clip_max = args.clip_max
        self.lr = args.lr
        self.beta1 = args.beta1
        self.loss_fn = nn.CrossEntropyLoss().cuda()
 def __init__(self, dataset_name, model, surrogate_model, meta_model,
              targeted, target_type, meta_predict_steps, finetune_times,
              finetune_lr):
     """
         Set up a meta-model based black-box attack.

         :param dataset_name: dataset key used for the loader and image geometry
         :param model: target model; moved to GPU and put in eval mode
         :param surrogate_model: surrogate model; moved to GPU and put in eval mode
         :param meta_model: meta network; a deep copy is kept as a second instance
         :param targeted: whether the attack is targeted (only untargeted supported now)
         :param target_type: strategy used to pick target labels
         :param meta_predict_steps: steps between true-model queries (stored only)
         :param finetune_times: number of meta-model finetuning iterations (stored only)
         :param finetune_lr: learning rate for both meta-model Adam optimizers
     """
     self.dataset_name = dataset_name
     # Bug fix: the loader previously read the global `args.dataset`, ignoring
     # the dataset_name parameter and breaking without that global in scope.
     self.data_loader = DataLoaderMaker.get_test_attacked_data(dataset_name, 1)
     self.image_height = IMAGE_SIZE[self.dataset_name][0]
     self.image_width = IMAGE_SIZE[self.dataset_name][1]
     self.in_channels = IN_CHANNELS[self.dataset_name]
     self.model = model
     self.surrogate_model = surrogate_model
     self.model.cuda().eval()
     self.surrogate_model.cuda().eval()
     self.targeted = targeted  # only support untargeted attack now
     self.target_type = target_type
     self.clip_min = 0.0
     self.clip_max = 1.0
     self.meta_predict_steps = meta_predict_steps
     self.finetune_times = finetune_times
     self.meta_model_for_q1 = meta_model
     self.meta_model_for_q2 = copy.deepcopy(meta_model)
     self.finetune_lr = finetune_lr
     # snapshot of the pristine meta weights so finetuning can be reset later
     self.pretrained_meta_weights = self.meta_model_for_q1.state_dict().copy()
     self.meta_optimizer_q1 = Adam(self.meta_model_for_q1.parameters(),
                                   lr=self.finetune_lr)
     self.meta_optimizer_q2 = Adam(self.meta_model_for_q2.parameters(),
                                   lr=self.finetune_lr)
     self.mse_loss = nn.MSELoss(reduction="mean")
# ---- Exemplo n.º 7 (Example 7) ----
    def __init__(self,
                 dataset,
                 targeted,
                 target_type,
                 epsilon,
                 norm,
                 batch_size,
                 lower_bound=0.0,
                 upper_bound=1.0,
                 max_queries=10000,
                 max_crit_queries=np.inf):
        """
            :param dataset: dataset name used to build the attacked test loader
            :param targeted: whether the attack is targeted
            :param target_type: strategy used to pick target labels
            :param epsilon: perturbation limit according to lp-ball
            :param norm: norm for the lp-ball constraint
            :param batch_size: number of images attacked per batch
            :param lower_bound: minimum value data point can take in any coordinate
            :param upper_bound: maximum value data point can take in any coordinate
            :param max_queries: max number of calls to model per data point
            :param max_crit_queries: max number of calls to early stopping criterion per data point
        """
        assert norm in ['linf', 'l2'], "{} is not supported".format(norm)
        assert not (np.isinf(max_queries) and np.isinf(max_crit_queries)
                    ), "one of the budgets has to be finite!"
        self.epsilon = epsilon
        self.norm = norm
        self.max_queries = max_queries
        self.max_crit_queries = max_crit_queries

        # per-batch search state, reset whenever a new batch begins
        self.best_est_deriv = None
        self.xo_t = None
        self.sgn_t = None
        self.h = np.zeros(batch_size).astype(np.int32)
        self.i = np.zeros(batch_size).astype(np.int32)
        self.exhausted = [False for _ in range(batch_size)]

        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self._proj = None
        self.is_new_batch = False
        # self.early_stop_crit_fct = lambda model, x, y: 1 - model(x).max(1)[1].eq(y)
        self.targeted = targeted
        self.target_type = target_type

        # Bug fix: the loader previously used the global `args.batch_size`,
        # which could disagree with the batch_size parameter used above to
        # size h / i / exhausted; use the parameter consistently.
        self.data_loader = DataLoaderMaker.get_test_attacked_data(
            dataset, batch_size)
        self.total_images = len(self.data_loader.dataset)
        self.image_height = IMAGE_SIZE[dataset][0]
        self.image_width = IMAGE_SIZE[dataset][1]
        self.in_channels = IN_CHANNELS[dataset]

        # per-image bookkeeping tensors, one float entry per test image
        self.query_all = torch.zeros(self.total_images)
        self.correct_all = torch.zeros_like(self.query_all)  # number of images
        self.not_done_all = torch.zeros_like(
            self.query_all
        )  # always set to 0 if the original image is misclassified
        self.success_all = torch.zeros_like(self.query_all)
        self.success_query_all = torch.zeros_like(self.query_all)
        self.not_done_prob_all = torch.zeros_like(self.query_all)
# ---- Exemplo n.º 8 (Example 8) ----
 def __init__(self, args):
     """Build the attacked test loader and zero-initialised result buffers."""
     self.dataset_loader = DataLoaderMaker.get_test_attacked_data(args.dataset, args.batch_size)
     n = len(self.dataset_loader.dataset)
     self.total_images = n
     # per-image bookkeeping, one float per test image
     self.query_all = torch.zeros(n)
     self.correct_all = torch.zeros(n)   # number of images classified correctly
     self.not_done_all = torch.zeros(n)  # stays 0 when the clean image is misclassified
     self.success_all = torch.zeros(n)
     self.success_query_all = torch.zeros(n)
     self.not_done_loss_all = torch.zeros(n)
     self.not_done_prob_all = torch.zeros(n)
# ---- Exemplo n.º 9 (Example 9) ----
    def __init__(self,
                 dataset,
                 batch_size,
                 pixel_attack,
                 freq_dims,
                 stride,
                 order,
                 max_iters,
                 targeted,
                 target_type,
                 norm,
                 pixel_epsilon,
                 l2_bound,
                 linf_bound,
                 lower_bound=0.0,
                 upper_bound=1.0):
        """
            :param dataset: dataset name used to build the attacked test loader
            :param batch_size: batch size of the attacked test data loader
            :param pixel_attack: whether to perturb in pixel space (stored for the attack loop)
            :param freq_dims: frequency-space dimension setting (stored for the attack loop)
            :param stride: stride setting (stored for the attack loop)
            :param order: coordinate visiting order (stored for the attack loop)
            :param max_iters: maximum number of attack iterations
            :param targeted: whether the attack is targeted
            :param target_type: strategy used to pick target labels
            :param norm: norm for the lp-ball constraint, either 'linf' or 'l2'
            :param pixel_epsilon: perturbation limit according to lp-ball
            :param l2_bound: l2 perturbation budget (stored for the attack loop)
            :param linf_bound: linf perturbation budget (stored for the attack loop)
            :param lower_bound: minimum value data point can take in any coordinate
            :param upper_bound: maximum value data point can take in any coordinate
        """
        assert norm in ['linf', 'l2'], "{} is not supported".format(norm)
        self.pixel_epsilon = pixel_epsilon
        self.dataset = dataset
        self.norm = norm
        self.pixel_attack = pixel_attack
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.freq_dims = freq_dims
        self.stride = stride
        self.order = order
        self.linf_bound = linf_bound
        self.l2_bound = l2_bound
        # self.early_stop_crit_fct = lambda model, x, y: 1 - model(x).max(1)[1].eq(y)
        self.max_iters = max_iters
        self.targeted = targeted
        self.target_type = target_type

        self.data_loader = DataLoaderMaker.get_test_attacked_data(
            dataset, batch_size)
        self.total_images = len(self.data_loader.dataset)
        self.image_height = IMAGE_SIZE[dataset][0]
        self.image_width = IMAGE_SIZE[dataset][1]
        self.in_channels = IN_CHANNELS[dataset]

        # per-image bookkeeping tensors, one float entry per test image
        self.query_all = torch.zeros(self.total_images)
        self.correct_all = torch.zeros_like(self.query_all)  # number of images
        self.success_all = torch.zeros_like(self.query_all)
        self.success_query_all = torch.zeros_like(self.query_all)
# ---- Exemplo n.º 10 (Example 10) ----
 def __init__(self, dataset_name, targeted):
     """Record dataset metadata and allocate per-image statistic tensors."""
     self.dataset_name = dataset_name
     self.num_classes = CLASS_NUM[self.dataset_name]
     # batch size of 1: images are attacked one at a time
     self.dataset_loader = DataLoaderMaker.get_test_attacked_data(dataset_name, 1)
     self.targeted = targeted
     n = len(self.dataset_loader.dataset)
     self.total_images = n
     self.query_all = torch.zeros(n)
     self.correct_all = torch.zeros(n)   # number of images classified correctly
     self.not_done_all = torch.zeros(n)  # stays 0 when the clean image is misclassified
     self.success_all = torch.zeros(n)
     self.success_query_all = torch.zeros(n)
# ---- Exemplo n.º 11 (Example 11) ----
 def __init__(self, dataset_name, model, surrogate_model, targeted, target_type):
     """
         Wrap a target and a surrogate model for the attack.

         :param dataset_name: dataset key used for the loader and image geometry
         :param model: target model; moved to GPU and put in eval mode
         :param surrogate_model: surrogate model; moved to GPU and put in eval mode
         :param targeted: whether the attack is targeted (only untargeted supported now)
         :param target_type: strategy used to pick target labels
     """
     self.dataset_name = dataset_name
     # Bug fix: the loader previously read the global `args.dataset`, ignoring
     # the dataset_name parameter and breaking without that global in scope.
     self.data_loader = DataLoaderMaker.get_test_attacked_data(dataset_name, 1)
     self.image_height = IMAGE_SIZE[self.dataset_name][0]
     self.image_width = IMAGE_SIZE[self.dataset_name][1]
     self.in_channels = IN_CHANNELS[self.dataset_name]
     self.model = model
     self.surrogate_model = surrogate_model
     self.model.cuda().eval()
     self.surrogate_model.cuda().eval()
     self.targeted = targeted  # only support untargeted attack now
     self.target_type = target_type
     self.clip_min = 0.0
     self.clip_max = 1.0
# ---- Exemplo n.º 12 (Example 12) ----
    def __init__(self, function, config, device):
        """
            Configure a Bayesian-optimization (EI) based attack.

            :param function: objective wrapper; must expose a `.model` attribute
            :param config: dict of attack settings (batch_size, epsilon,
                           query_limit, max_iters, init_iter, init_batch,
                           memory_size, channels, image_height, image_width,
                           local_forget_threshold, lr, max_queries)
            :param device: torch device used for the GP bookkeeping tensors
        """
        self.config = config
        self.batch_size = config['batch_size']
        self.function = function
        self.model = function.model
        self.device = device
        self.epsilon = self.config['epsilon']
        # Gaussian-process based search helper over a 4-dimensional space
        self.gp = attack_bayesian_EI.Attack(
            f=self,
            dim=4,
            max_evals=1000,
            verbose=True,
            use_ard=True,
            max_cholesky_size=2000,
            n_training_steps=30,
            device=device,
            dtype="float32",
        )
        self.query_limit = self.config['query_limit']
        self.max_iters = self.config['max_iters']
        self.init_iter = self.config["init_iter"]
        self.init_batch = self.config["init_batch"]
        self.memory_size = self.config["memory_size"]
        self.channels = self.config["channels"]
        self.image_height = self.config["image_height"]
        self.image_width = self.config["image_width"]
        # empty placeholders for the GP's observed points / values
        self.gp_emptyX = torch.zeros((1, 4), device=device)
        self.gp_emptyfX = torch.zeros((1), device=device)
        self.local_forget_threshold = self.config['local_forget_threshold']
        self.lr = self.config['lr']

        # NOTE(review): reads the module-level `args` rather than `config` —
        # verify that global exists wherever this class is constructed.
        self.dataset_loader = DataLoaderMaker.get_test_attacked_data(
            args.dataset, args.batch_size)
        self.total_images = len(self.dataset_loader.dataset)
        self.query_all = torch.zeros(self.total_images)
        self.correct_all = torch.zeros_like(self.query_all)  # number of images
        self.not_done_all = torch.zeros_like(
            self.query_all
        )  # always set to 0 if the original image is misclassified
        self.success_all = torch.zeros_like(self.query_all)
        self.success_query_all = torch.zeros_like(self.query_all)
        self.maximum_queries = self.config["max_queries"]
# ---- Exemplo n.º 13 (Example 13) ----
    def __init__(self,
                 pop_size=5,
                 generations=1000,
                 cross_rate=0.7,
                 mutation_rate=0.001,
                 max_queries=2000,
                 epsilon=8. / 255,
                 iters=10,
                 ensemble_models=None,
                 targeted=False):
        """Genetic-algorithm attack: set up evolution, attack, and result state."""
        self.loss_fn = nn.CrossEntropyLoss()
        self.dataset_loader = DataLoaderMaker.get_test_attacked_data(
            args.dataset, 1)
        self.total_images = len(self.dataset_loader.dataset)
        # evolution-algorithm hyper-parameters
        self.pop_size = pop_size
        self.generations = generations
        self.cross_rate = cross_rate
        self.mutation_rate = mutation_rate
        # attack hyper-parameters
        self.epsilon = epsilon
        self.clip_min = 0
        self.clip_max = 1
        # ensemble MI-FGSM parameters, use ensemble MI-FGSM attack generate adv as initial population
        self.ensemble_models = ensemble_models
        self.iters = iters
        self.targeted = targeted
        self.max_queries = max_queries
        # draw two distinct population indices up front
        self.idx = np.random.choice(np.arange(self.pop_size), size=2, replace=False)
        self.is_change = np.zeros(self.pop_size)
        self.pop_fitness = np.zeros(self.pop_size)
        # per-image result buffers
        n = self.total_images
        self.query_all = torch.zeros(n)
        self.correct_all = torch.zeros(n)   # number of images classified correctly
        self.not_done_all = torch.zeros(n)  # stays 0 when the original image is misclassified
        self.success_all = torch.zeros(n)
        self.success_query_all = torch.zeros(n)
# ---- Exemplo n.º 14 (Example 14) ----
    def __init__(self, dataset_name, targeted):
        """Build the label-to-sample-index lookup over the attacked test set."""
        self.dataset_name = dataset_name
        self.num_classes = CLASS_NUM[self.dataset_name]
        self.dataset_loader = DataLoaderMaker.get_test_attacked_data(
            dataset_name, 1)

        log.info("label index dict data build begin")
        self.dataset = self.dataset_loader.dataset
        self.label_data_index_dict = self.get_label_dataset(self.dataset)
        log.info("label index dict data build over!")

        self.targeted = targeted
        n = len(self.dataset_loader.dataset)
        self.total_images = n
        # per-image result buffers
        self.query_all = torch.zeros(n)
        self.correct_all = torch.zeros(n)   # number of images classified correctly
        self.not_done_all = torch.zeros(n)  # stays 0 when the clean image is misclassified
        self.success_all = torch.zeros(n)
        self.success_query_all = torch.zeros(n)
# ---- Exemplo n.º 15 (Example 15) ----
             test_model_list_path = "{}/train_pytorch_model/real_image_model/{}-pretrained/checkpoints/{}*.pth".format(
                 PY_ROOT, args.dataset, arch)
             test_model_list_path = list(glob.glob(test_model_list_path))
             if len(test_model_list_path
                    ) == 0:  # this arch does not exists in args.dataset
                 continue
             archs.append(arch)
 else:
     assert args.arch is not None
     archs = [args.arch]
 args.arch = ", ".join(archs)
 log.info('Command line is: {}'.format(' '.join(sys.argv)))
 log.info("Log file is written in {}".format(log_file_path))
 log.info('Called with args:')
 print_args(args)
 data_loader = DataLoaderMaker.get_test_attacked_data(args.dataset, 1)
 for arch in archs:
     if args.attack_defense:
         save_result_path = args.exp_dir + "/{}_{}_result.json".format(
             arch, args.defense_model)
     else:
         save_result_path = args.exp_dir + "/{}_result.json".format(arch)
     if os.path.exists(save_result_path):
         continue
     log.info("Begin attack {} on {}, result will be saved to {}".format(
         arch, args.dataset, save_result_path))
     if args.attack_defense:
         model = DefensiveModel(args.dataset,
                                arch,
                                no_grad=True,
                                defense_model=args.defense_model)
# ---- Exemplo n.º 16 (Example 16) ----
                    arch)
        if os.path.exists(save_result_path):
            continue
        log.info("Begin attack {} on {}, result will be saved to {}".format(
            arch, args.dataset, save_result_path))

        if args.attack_defense:
            model = DefensiveModel(args.dataset,
                                   arch,
                                   no_grad=True,
                                   defense_model=args.defense_model)
        else:
            model = StandardModel(args.dataset, arch, no_grad=True)
        model.cuda()
        model.eval()
        dataset_loader = DataLoaderMaker.get_test_attacked_data(
            args.dataset, args.batch_size)
        success_all = torch.zeros(dataset_loader.dataset.__len__()).float()
        correct_all = torch.zeros(dataset_loader.dataset.__len__()).float()
        query_all = torch.zeros(dataset_loader.dataset.__len__()).float()
        for batch_idx, data_tuple in enumerate(dataset_loader):
            if args.dataset == "ImageNet":
                if model.input_size[-1] >= 299:
                    images, true_labels = data_tuple[1], data_tuple[2]
                else:
                    images, true_labels = data_tuple[0], data_tuple[2]
            else:
                images, true_labels = data_tuple[0], data_tuple[1]
            if args.targeted:
                target_labels = torch.fmod(true_labels + 1,
                                           CLASS_NUM[args.dataset])
            else:
# ---- Exemplo n.º 17 (Example 17) ----
def main():
    """Attack each test architecture with the TREMBA generator and save stats.

    Loads the encoder/decoder generator weights, optionally prepares an OSP
    source model for white-box latent warm-starting, attacks every
    correctly-classified test image with EmbedBA, and writes per-arch query
    statistics to a JSON result file.
    """
    args = get_args()
    with open(args.config) as config_file:
        state = json.load(config_file)["attack"][args.targeted]
        state = SimpleNamespace(**state)
    if args.save_prefix is not None:
        state.save_prefix = args.save_prefix
    if args.arch is not None:
        state.arch = args.arch
    if args.test_archs is not None:
        state.test_archs = args.test_archs
    state.OSP = args.OSP
    state.targeted = args.targeted

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    targeted_str = "untargeted" if not state.targeted else "targeted"
    # Both branches of the original if/else built the identical path; the
    # conditional was collapsed into a single assignment.
    save_name = "{}/train_pytorch_model/TREMBA/{}_{}_generator.pth.tar".format(
        PY_ROOT, args.dataset, targeted_str)
    weight = torch.load(save_name, map_location=device)["state_dict"]
    data_loader = DataLoaderMaker.get_test_attacked_data(
        args.dataset, args.batch_size)
    # The generator checkpoint stores encoder weights under '0.*' and decoder
    # weights under '1.*'; split them for the two modules.
    encoder_weight = {}
    decoder_weight = {}
    for key, val in weight.items():
        if key.startswith('0.'):
            encoder_weight[key[2:]] = val
        elif key.startswith('1.'):
            decoder_weight[key[2:]] = val
    archs = []
    if args.test_archs:
        if args.dataset == "CIFAR-10" or args.dataset == "CIFAR-100":
            for arch in MODELS_TEST_STANDARD[args.dataset]:
                test_model_path = "{}/train_pytorch_model/real_image_model/{}-pretrained/{}/checkpoint.pth.tar".format(
                    PY_ROOT, args.dataset, arch)
                if os.path.exists(test_model_path):
                    archs.append(arch)
                else:
                    log.info(test_model_path + " does not exists!")
        elif args.dataset == "TinyImageNet":
            for arch in MODELS_TEST_STANDARD[args.dataset]:
                test_model_list_path = "{root}/train_pytorch_model/real_image_model/{dataset}@{arch}*.pth.tar".format(
                    root=PY_ROOT, dataset=args.dataset, arch=arch)
                test_model_path = list(glob.glob(test_model_list_path))
                if test_model_path and os.path.exists(test_model_path[0]):
                    archs.append(arch)
        else:
            for arch in MODELS_TEST_STANDARD[args.dataset]:
                test_model_list_path = "{}/train_pytorch_model/real_image_model/{}-pretrained/checkpoints/{}*.pth".format(
                    PY_ROOT, args.dataset, arch)
                test_model_list_path = list(glob.glob(test_model_list_path))
                if len(test_model_list_path
                       ) == 0:  # this arch does not exists in args.dataset
                    continue
                archs.append(arch)
        args.arch = ",".join(archs)
    else:
        archs.append(args.arch)

    args.exp_dir = get_exp_dir_name(args.dataset, args.norm, args.targeted,
                                    args.target_type, args)
    for arch in archs:
        if args.attack_defense:
            save_result_path = args.exp_dir + "/{}_{}_result.json".format(
                arch, args.defense_model)
        else:
            save_result_path = args.exp_dir + "/{}_result.json".format(arch)
        if os.path.exists(save_result_path):
            continue
        if args.OSP:
            if state.source_model_name == "Adv_Denoise_Resnet152":
                source_model = resnet152_denoise()
                loaded_state_dict = torch.load((os.path.join(
                    '{}/train_pytorch_model/TREMBA'.format(PY_ROOT),
                    state.source_model_name + ".pth.tar")))
                source_model.load_state_dict(loaded_state_dict)
                mean = np.array([0.485, 0.456, 0.406])
                std = np.array([0.229, 0.224, 0.225])
                # FIXME: still needs further revision
                source_model = nn.Sequential(Normalize(mean, std),
                                             source_model)
                source_model.to(device)
                source_model.eval()

        if args.attack_defense:
            model = DefensiveModel(args.dataset,
                                   arch,
                                   no_grad=True,
                                   defense_model=args.defense_model)
        else:
            model = StandardModel(args.dataset, arch, no_grad=True)

        model.eval()
        encoder = ImagenetEncoder()
        decoder = ImagenetDecoder(args.dataset)
        encoder.load_state_dict(encoder_weight)
        decoder.load_state_dict(decoder_weight)
        model.to(device)
        encoder.to(device)
        encoder.eval()
        decoder.to(device)
        decoder.eval()
        F = Function(model, state.batch_size, state.margin,
                     CLASS_NUM[args.dataset], state.targeted)
        total_success = 0
        count_total = 0
        queries = []
        not_done = []
        correct_all = []

        for i, (images, labels) in enumerate(data_loader):
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            correct = torch.argmax(logits, dim=1).eq(labels).item()
            correct_all.append(int(correct))
            if correct:
                if args.targeted:
                    if args.target_type == 'random':
                        target_labels = torch.randint(
                            low=0,
                            high=CLASS_NUM[args.dataset],
                            size=labels.size()).long().cuda()
                        invalid_target_index = target_labels.eq(labels)
                        while invalid_target_index.sum().item() > 0:
                            # Bug fix: this resampling loop referenced an
                            # undefined name `logit` (NameError on first pass
                            # of the 'random' branch); use the logits computed
                            # above for the class count.
                            target_labels[
                                invalid_target_index] = torch.randint(
                                    low=0,
                                    high=logits.shape[1],
                                    size=target_labels[invalid_target_index].
                                    shape).long().cuda()
                            invalid_target_index = target_labels.eq(labels)
                    elif args.target_type == 'least_likely':
                        with torch.no_grad():
                            logit = model(images)
                        target_labels = logit.argmin(dim=1)
                    elif args.target_type == "increment":
                        target_labels = torch.fmod(
                            labels + 1, CLASS_NUM[args.dataset]).cuda()
                    labels = target_labels[0].item()
                else:
                    labels = labels[0].item()
                if args.OSP:
                    # NOTE(review): `state.target` (not `state.targeted`) is
                    # passed here — confirm the config actually defines it.
                    hinge_loss = MarginLossSingle(state.white_box_margin,
                                                  state.target)
                    images.requires_grad = True
                    latents = encoder(images)
                    # white-box warm start: descend the hinge loss in latent space
                    for k in range(state.white_box_iters):
                        perturbations = decoder(latents) * state.epsilon
                        logits = source_model(
                            torch.clamp(images + perturbations, 0, 1))
                        loss = hinge_loss(logits, labels)
                        grad = torch.autograd.grad(loss, latents)[0]
                        latents = latents - state.white_box_lr * grad
                    with torch.no_grad():
                        success, adv, query_count = EmbedBA(
                            F, encoder, decoder, images[0], labels, state,
                            latents.view(-1))
                else:
                    with torch.no_grad():
                        success, adv, query_count = EmbedBA(
                            F, encoder, decoder, images[0], labels, state)
                not_done.append(1 - int(success))
                total_success += int(success)
                count_total += int(correct)
                if success:
                    queries.append(query_count)
                else:
                    queries.append(args.max_queries)

                log.info(
                    "image: {} eval_count: {} success: {} average_count: {} success_rate: {}"
                    .format(i, F.current_counts, success, F.get_average(),
                            float(total_success) / float(count_total)))
                F.new_counter()
            else:
                queries.append(0)
                not_done.append(1)
                log.info("The {}-th image is already classified incorrectly.".
                         format(i))
        # Bug fix: np.concatenate raises on a list of Python ints
        # (zero-dimensional inputs cannot be concatenated); build the array
        # directly instead.
        correct_all = np.array(correct_all).astype(np.int32)
        query_all = np.array(queries).astype(np.int32)
        not_done_all = np.array(not_done).astype(np.int32)
        success = (1 - not_done_all) * correct_all
        success_query = success * query_all
        meta_info_dict = {
            "query_all":
            query_all.tolist(),
            "not_done_all":
            not_done_all.tolist(),
            "correct_all":
            correct_all.tolist(),
            "mean_query":
            np.mean(success_query[np.nonzero(success)[0]]).item(),
            "max_query":
            np.max(success_query[np.nonzero(success)[0]]).item(),
            "median_query":
            np.median(success_query[np.nonzero(success)[0]]).item(),
            "avg_not_done":
            np.mean(not_done_all[np.nonzero(correct_all)[0]].astype(
                np.float32)).item(),
            "args":
            vars(args)
        }

        with open(save_result_path, "w") as result_file_obj:
            json.dump(meta_info_dict, result_file_obj, sort_keys=True)
        log.info("Done, write stats info to {}".format(save_result_path))
# ---- Exemplo n.º 18 (Example 18) ----
def main(args, result_dir_path):
    """Attack every selected architecture with MetaGradAttack and dump per-arch JSON stats.

    For each architecture, iterate the attacked test set one image at a time,
    skip images the target model already misclassifies, run the attack on the
    rest, and record query counts / success flags.  Results go to
    ``<result_dir_path>/<arch>[_<defense_model>]_result.json``.

    :param args: parsed command-line namespace (dataset, targeted, max_queries, ...)
    :param result_dir_path: directory receiving the per-architecture result JSON files
    """
    log.info('Loading %s model and test data' % args.dataset)
    # batch size 1: the attack below processes a single image per iteration
    dataset_loader = DataLoaderMaker.get_test_attacked_data(args.dataset, 1)
    # Collect target architectures: either probe the disk for every standard
    # test model of this dataset, or use the single --arch argument.
    if args.test_archs:
        archs = []
        if args.dataset == "CIFAR-10" or args.dataset == "CIFAR-100":
            for arch in MODELS_TEST_STANDARD[args.dataset]:
                test_model_path = "{}/train_pytorch_model/real_image_model/{}-pretrained/{}/checkpoint.pth.tar".format(
                    PY_ROOT, args.dataset, arch)
                if os.path.exists(test_model_path):
                    archs.append(arch)
                else:
                    log.info(test_model_path + " does not exists!")
        elif args.dataset == "TinyImageNet":
            for arch in MODELS_TEST_STANDARD[args.dataset]:
                test_model_list_path = "{root}/train_pytorch_model/real_image_model/{dataset}@{arch}*.pth.tar".format(
                    root=PY_ROOT, dataset=args.dataset, arch=arch)
                test_model_path = list(glob.glob(test_model_list_path))
                if test_model_path and os.path.exists(test_model_path[0]):
                    archs.append(arch)
        else:
            for arch in MODELS_TEST_STANDARD[args.dataset]:
                test_model_list_path = "{}/train_pytorch_model/real_image_model/{}-pretrained/checkpoints/{}*.pth".format(
                    PY_ROOT, args.dataset, arch)
                test_model_list_path = list(glob.glob(test_model_list_path))
                if len(test_model_list_path
                       ) == 0:  # this arch does not exists in args.dataset
                    continue
                archs.append(arch)
    else:
        assert args.arch is not None
        archs = [args.arch]
    # The meta model predicts gradients; attacking a defended model uses the
    # variant trained without resnet architectures.
    if args.attack_defense:
        meta_model_path = '{}/train_pytorch_model/meta_grad_regression/{}_without_resnet.pth.tar'.format(
            PY_ROOT, args.dataset)
    else:
        meta_model_path = '{}/train_pytorch_model/meta_grad_regression/{}.pth.tar'.format(
            PY_ROOT, args.dataset)
    assert os.path.exists(meta_model_path), "{} does not exist!".format(
        meta_model_path)
    meta_model = load_meta_model(meta_model_path)

    log.info("Load meta model from {}".format(meta_model_path))
    attack = MetaGradAttack(args,
                            norm=args.norm,
                            epsilon=args.epsilon,
                            targeted=args.targeted,
                            search_steps=args.binary_steps,
                            max_steps=args.maxiter,
                            use_log=not args.use_zvalue,
                            cuda=not args.no_cuda)

    for arch in archs:
        if args.attack_defense:
            model = DefensiveModel(args.dataset,
                                   arch,
                                   no_grad=True,
                                   defense_model=args.defense_model)
        else:
            model = StandardModel(args.dataset, arch, no_grad=True)

        model.cuda().eval()
        # per-image bookkeeping for this architecture
        query_all = []
        not_done_all = []
        correct_all = []
        img_no = 0          # number of correctly-classified (attacked) images
        total_success = 0
        l2_total = 0.0      # accumulated L2 distortion of successful attacks
        avg_step = 0
        avg_time = 0
        avg_qry = 0
        if args.attack_defense:
            result_dump_path = result_dir_path + "/{}_{}_result.json".format(
                arch, args.defense_model)
        else:
            result_dump_path = result_dir_path + "/{}_result.json".format(arch)

        # if os.path.exists(result_dump_path):
        #     continue
        log.info("Begin attack {} on {}".format(arch, args.dataset))
        for i, data_tuple in enumerate(dataset_loader):
            # ImageNet loaders yield two image resolutions; pick the one
            # matching the model's expected input size
            if args.dataset == "ImageNet":
                if model.input_size[-1] >= 299:
                    img, true_labels = data_tuple[1], data_tuple[2]
                else:
                    img, true_labels = data_tuple[0], data_tuple[2]
            else:
                img, true_labels = data_tuple[0], data_tuple[1]
            args.init_size = model.input_size[-1]
            if img.size(-1) != model.input_size[-1]:
                img = F.interpolate(img,
                                    size=model.input_size[-1],
                                    mode='bilinear',
                                    align_corners=True)

            img, true_labels = img.to(0), true_labels.to(0)
            with torch.no_grad():
                pred_logit = model(img)
            pred_label = pred_logit.argmax(dim=1)
            correct = pred_label.eq(true_labels).detach().cpu().numpy().astype(
                np.int32)
            correct_all.append(correct)
            # Skip already-misclassified images: they count as 0 queries and
            # not_done = 1 (excluded from success statistics via correct_all).
            if pred_label[0].item() != true_labels[0].item():
                log.info(
                    "Skip wrongly classified image no. %d, original class %d, classified as %d"
                    % (i, pred_label.item(), true_labels.item()))
                query_all.append(0)
                not_done_all.append(
                    1)  # already misclassified, so not_done = 1; if 99% were originally misclassified, avg_not_done = 99%
                continue
            img_no += 1
            timestart = time.time()
            # deep-copy so per-image fine-tuning never mutates the shared meta model
            meta_model_copy = copy.deepcopy(meta_model)
            if args.targeted:
                if args.target_type == 'random':
                    target_labels = torch.randint(
                        low=0,
                        high=CLASS_NUM[args.dataset],
                        size=true_labels.size()).long().cuda()
                    # re-draw until the random target differs from the true label
                    invalid_target_index = target_labels.eq(true_labels)
                    while invalid_target_index.sum().item() > 0:
                        target_labels[invalid_target_index] = torch.randint(
                            low=0,
                            high=pred_logit.shape[1],
                            size=target_labels[invalid_target_index].shape
                        ).long().cuda()
                        invalid_target_index = target_labels.eq(true_labels)
                elif args.target_type == 'least_likely':
                    target_labels = pred_logit.argmin(dim=1)
                elif args.target_type == "increment":
                    target_labels = torch.fmod(true_labels + 1,
                                               CLASS_NUM[args.dataset])
                else:
                    raise NotImplementedError('Unknown target_type: {}'.format(
                        args.target_type))
            else:
                target_labels = None
            target = true_labels if not args.targeted else target_labels
            adv, const, first_step, success_queries = attack.run(
                model, meta_model_copy, img, target)
            timeend = time.time()
            if len(adv.shape) == 3:
                adv = adv.reshape((1, ) + adv.shape)
            adv = torch.from_numpy(adv).permute(0, 3, 1,
                                                2).cuda()  # BHWC -> BCHW
            diff = (adv - img).detach().cpu().numpy()
            l2_distortion = np.sqrt(np.sum(np.square(diff))).item()
            with torch.no_grad():
                adv_pred_logit = model(adv)
                adv_pred_label = adv_pred_logit.argmax(dim=1)
            success = False
            if not args.targeted:  # target is true label
                if adv_pred_label[0].item() != target[0].item():
                    success = True
            else:
                if adv_pred_label[0].item() == target[0].item():
                    success = True
            # an attack that exceeded the query budget does not count as success
            if success_queries > args.max_queries:
                success = False
            if success:
                # (first_step-1)//args.finetune_intervalargs.update_pixels2+first_step
                # The first step is the iterations used that find the adversarial examples;
                # args.finetune_interval is the finetuning per iterations;
                # args.update_pixels is the queried pixels each iteration.
                # Currently, we find only f(x+h)-f(x) could estimate the gradient well, so args.update_pixels*1 in my updated codes.
                not_done_all.append(0)
                # only 1 query for i pixle, because the estimated function is f(x+h)-f(x)/h
                query_all.append(success_queries)
                total_success += 1
                l2_total += l2_distortion
                avg_step += first_step
                avg_time += timeend - timestart
                # avg_qry += (first_step-1)//args.finetune_interval*args.update_pixels*1+first_step
                avg_qry += success_queries
                log.info("Attack {}-th image: {}, query:{}".format(
                    i, "success", success_queries))
            else:
                not_done_all.append(1)
                # failed attacks are charged the full query budget
                query_all.append(args.max_queries)
                log.info("Attack {}-th image: {}, query:{}".format(
                    i, "fail", success_queries))
        model.cpu()
        if total_success != 0:
            log.info(
                "[STATS][L1] total = {}, time = {:.3f}, distortion = {:.5f}, avg_step = {:.5f},avg_query = {:.5f}, success_rate = {:.3f}"
                .format(img_no, avg_time / total_success,
                        l2_total / total_success, avg_step / total_success,
                        avg_qry / total_success,
                        total_success / float(img_no)))
        correct_all = np.concatenate(correct_all, axis=0).astype(np.int32)
        query_all = np.array(query_all).astype(np.int32)
        not_done_all = np.array(not_done_all).astype(np.int32)
        # success = finished AND originally correctly classified
        success = (1 - not_done_all) * correct_all
        success_query = success * query_all

        # query_all_bounded = query_all.copy()
        # not_done_all_bounded = not_done_all.copy()
        # out_of_bound_indexes = np.where(query_all_bounded > args.max_queries)[0]
        # if len(out_of_bound_indexes) > 0:
        #     not_done_all_bounded[out_of_bound_indexes] = 1
        # success_bounded = (1-not_done_all_bounded) * correct_all
        # success_query_bounded = success_bounded * query_all_bounded
        #
        # query_threshold_success_rate_bounded, query_success_rate_bounded = success_rate_and_query_coorelation(query_all_bounded, not_done_all_bounded)
        # success_rate_to_avg_query_bounded = success_rate_avg_query(query_all_bounded, not_done_all_bounded)

        # NOTE(review): if no image succeeds, np.mean/np.max over an empty
        # slice below yields nan / raises -- presumably success is never empty
        # in practice; confirm before reusing.
        meta_info_dict = {
            "query_all":
            query_all.tolist(),
            "not_done_all":
            not_done_all.tolist(),
            "correct_all":
            correct_all.tolist(),
            "mean_query":
            np.mean(success_query[np.nonzero(success)[0]]).item(),
            "max_query":
            np.max(success_query[np.nonzero(success)[0]]).item(),
            "median_query":
            np.median(success_query[np.nonzero(success)[0]]).item(),
            "avg_not_done":
            np.mean(
                not_done_all.astype(
                    np.float32)[np.nonzero(correct_all)[0]]).item(),

            # "mean_query_bounded_max_queries": np.mean(success_query_bounded[np.nonzero(success_bounded)[0]]).item(),
            # "max_query_bounded_max_queries": np.max(success_query_bounded[np.nonzero(success_bounded)[0]]).item(),
            # "median_query_bounded_max_queries": np.median(success_query_bounded[np.nonzero(success_bounded)[0]]).item(),
            # "avg_not_done_bounded_max_queries": np.mean(not_done_all_bounded.astype(np.float32)).item(),

            # "query_threshold_success_rate_dict_bounded": query_threshold_success_rate_bounded,
            # "query_success_rate_dict_bounded": query_success_rate_bounded,
            # "success_rate_to_avg_query_bounded": success_rate_to_avg_query_bounded,
            "args":
            vars(args)
        }
        with open(result_dump_path, "w") as result_file_obj:
            json.dump(meta_info_dict, result_file_obj, sort_keys=True)
        log.info("done, write stats info to {}".format(result_dump_path))
# ---- Exemplo n.º 19 (scrape-site example separator; vote count: 0) ----
def main():
    """Train the ComDefend compression/reconstruction defense model.

    Parses command-line options, configures CUDA device visibility and
    logging, generates adversarial examples once up front for evaluation,
    then trains the defender for ``--n_epoch`` epochs, checkpointing to
    ``model_path``.
    """
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('--arch', type=str, required=True, help="The arch used to generate adversarial images for testing")
    parser.add_argument("--gpu",type=str,required=True)
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument('--test_mode', default=0, type=int, choices=list(range(10)))
    # parser.add_argument('--model', default='res', type=str)
    parser.add_argument('--n_epoch', default=200, type=int)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--test_batch_size', default=10, type=int)
    parser.add_argument('--lambd', default=0.0001, type=float)
    parser.add_argument('--noise_dev', default=20.0, type=float)
    parser.add_argument('--Linfinity', default=8/255, type=float)
    parser.add_argument('--binary_threshold', default=0.5, type=float)
    parser.add_argument('--lr_mode', default=0, type=int)
    parser.add_argument('--test_interval', default=100, type=int)
    parser.add_argument('--lr', default=1e-3, type=float)
    parser.add_argument("--use_res_net",action="store_true")
    args = parser.parse_args()
    # restrict CUDA to the requested GPU before any torch CUDA context is created
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cudnn.deterministic = True
    model_path = '{}/train_pytorch_model/adversarial_train/com_defend/{}@{}@epoch_{}@batch_{}.pth.tar'.format(
        PY_ROOT, args.dataset, args.arch, args.n_epoch, args.batch_size)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    set_log_file(os.path.dirname(model_path) + "/train_{}.log".format(args.dataset))
    log.info('Command line is: {}'.format(' '.join(sys.argv)))
    log.info('Called with args:')
    print_args(args)

    in_channels = IN_CHANNELS[args.dataset]
    # if args.use_res_net:
    #     if args.test_mode == 0:
    #         com_defender = ModelRes(in_channels=in_channels, com_disable=True,rec_disable=True)
    #         args.save_model = 'normal'
    #     elif args.test_mode == 1:
    #         com_defender = ModelRes(in_channels=in_channels, n_com=1,n_rec=3,com_disable=False,rec_disable=True)
    #         args.save_model = '1_on_off'
    #     elif args.test_mode == 2:
    #         com_defender = ModelRes(in_channels=in_channels, n_com=2,n_rec=3,com_disable=False,rec_disable=True)
    #         args.save_model = '2_on_off'
    #     elif args.test_mode == 3:
    #         com_defender = ModelRes(in_channels=in_channels, n_com=3,n_rec=3,com_disable=False,rec_disable=True)
    #         args.save_model = '3_on_off'
    #     elif args.test_mode == 4:
    #         com_defender = ModelRes(in_channels=in_channels, n_com=3,n_rec=1,com_disable=True,rec_disable=False)
    #         args.save_model = 'off_on_1'
    #     elif args.test_mode == 5:
    #         com_defender = ModelRes(in_channels=in_channels, n_com=3,n_rec=2,com_disable=True,rec_disable=False)
    #         args.save_model = 'off_on_2'
    #     elif args.test_mode == 6:
    #         com_defender = ModelRes(in_channels=in_channels, n_com=3,n_rec=3,com_disable=True,rec_disable=False)
    #         args.save_model = 'off_on_3'
    #     elif args.test_mode == 7:
    #         com_defender = ModelRes(in_channels=in_channels, n_com=1,n_rec=1,com_disable=False,rec_disable=False)
    #         args.save_model = '1_1'
    #     elif args.test_mode == 8:
    #         com_defender = ModelRes(in_channels=in_channels, n_com=2,n_rec=2,com_disable=False,rec_disable=False)
    #         args.save_model = '2_2'
    #     elif args.test_mode == 9:
    #         com_defender = ModelRes(in_channels=in_channels, n_com=3,n_rec=3,com_disable=False,rec_disable=False)
    #         args.save_model = '3_3'
    # else:
    com_defender = ComDefend(in_channels, args.noise_dev)
    args.save_model = "normal_network"
    log.info('test mode: {}, model name: {}'.format(args.test_mode, args.save_model))

    if args.gpu is not None:
        log.info("Use GPU: {} for training".format(args.gpu))
    log.info("=> creating model '{}'".format(args.arch))

    log.info("after train, model will be saved to {}".format(model_path))
    com_defender.cuda()
    cudnn.benchmark = True
    train_loader = DataLoaderMaker.get_imgid_img_label_data_loader(args.dataset, args.batch_size, True, seed=1234)
    test_attack_dataset_loader = DataLoaderMaker.get_test_attacked_data(args.dataset, args.batch_size)
    log.info("Begin generate the adversarial examples.")
    target_model, adv_images, adv_true_labels = test_attack(args.Linfinity, args.arch, args.dataset,
                                                      test_attack_dataset_loader)  # these images are used for evaluation

    log.info("Generate adversarial examples done!")
    # best_acc is a 1-element tensor so train() can update it in place across epochs
    best_acc = torch.zeros(1)
    for epoch in range(0, args.n_epoch):
        train(args, train_loader, com_defender, epoch, target_model, adv_images, adv_true_labels,best_acc, model_path)
def main(target_model,  result_dump_path, args):
    """Prior-guided (bandits / subspace) black-box attack on ``target_model``.

    Each step estimates the victim's gradient with a two-query finite
    difference along a search direction -- drawn from surrogate ("ref")
    models' gradients when available, otherwise random -- folds the estimate
    into a running prior, and takes a projected step on the adversarial
    image.  Aggregate statistics are dumped to ``result_dump_path`` as JSON.

    Fixes vs. the original:
      * ``calc_ref_grad`` compared against an undefined name ``model``; the
        victim in this function is ``target_model``.
      * ``if adv_image_.grad:`` truth-tested a tensor (RuntimeError for any
        multi-element gradient); replaced by an explicit ``is not None`` check.
      * targeted not_done update used ``1 - adv_pred.eq(target)``, which fails
        on bool tensors in modern PyTorch; ``ne()`` is exactly equivalent.

    :param target_model: victim model; only its forward pass is queried
    :param result_dump_path: path of the JSON file receiving the statistics
    :param args: attack hyper-parameters (batch_size, epsilon, norm_type,
                 prior_update, loss, max_queries, ref_arch, ...)
    """
    # make model(s)
    log.info('Initializing target model {} on {}'.format(args.arch, args.dataset))

    # surrogate models whose gradients supply the search directions
    ref_models = OrderedDict()
    for i, ref_arch in enumerate(args.ref_arch):
        params = dict()
        params['train_data'] = args.ref_arch_train_data
        params['epoch'] = args.ref_arch_epoch
        log.info('Initializing ref model {} on {} ({} of {}), params: {}'.format(
            ref_arch, args.dataset, i + 1, len(args.ref_arch), params))
        ref_models[ref_arch] = StandardModel(args.dataset, ref_arch, no_grad=False, is_subspace_attack_ref_arch=True,
                                             ref_arch_train_data=args.ref_arch_train_data, ref_arch_epoch=args.ref_arch_epoch).cuda().eval()
    log.info('All target_models have been initialized, including 1 target model and {} ref target_models'.format(len(args.ref_arch)))

    # make loader
    loader = DataLoaderMaker.get_test_attacked_data(args.dataset, args.batch_size)
    total_images = len(loader.dataset)

    # make operators: concrete step/projection functions are resolved by name,
    # e.g. linf_image_step, l2_proj_step
    prior_step = eval('{}_prior_step'.format(args.prior_update))
    image_step = eval('{}_image_step'.format(args.norm_type))
    proj_step = eval('{}_proj_step'.format(args.norm_type))

    if args.delta_size > 0:
        # the search runs in a low-dimensional subspace; upsample back to the
        # model's input resolution before querying
        upsampler = lambda x: F.interpolate(x, size=target_model.input_size[-1], mode='bilinear', align_corners=True)
        downsampler = lambda x: F.interpolate(x, size=args.delta_size, mode='bilinear', align_corners=True)
    else:
        # no resize, upsampler = downsampler = identity
        # (e.g. CIFAR-10 needs no resizing, while ImageNet does)
        upsampler = downsampler = lambda x: x

    # make loss function
    loss_func = eval('{}_loss'.format(args.loss))

    # init arrays for saving results
    query_all = torch.zeros(total_images)
    correct_all = torch.zeros_like(query_all)
    not_done_all = torch.zeros_like(query_all)  # always set to 0 if the original image is misclassified
    success_all = torch.zeros_like(query_all)
    success_query_all = torch.zeros_like(query_all)
    not_done_loss_all = torch.zeros_like(query_all)
    not_done_prob_all = torch.zeros_like(query_all)

    # make directory for saving results
    result_dirname = osp.join(args.exp_dir, 'results')
    os.makedirs(result_dirname, exist_ok=True)

    # fixed direction for illustration experiments
    # NOTE(review): `device` is a module-level global here -- confirm it
    # matches the .cuda() placement used above.
    if args.num_fix_direction > 0:
        if len(args.ref_arch) == 0:
            # fixed random direction
            assert args.dataset == 'CIFAR-10'
            # save/restore numpy RNG state so the fixed direction is
            # reproducible without perturbing the global stream
            state = np.random.get_state()
            np.random.seed(args.fix_direction_seed)
            fix_direction = np.random.randn(3072, *target_model.input_size)[:args.num_fix_direction]
            np.random.set_state(state)
            fix_direction = np.ascontiguousarray(fix_direction)
            fix_direction = torch.FloatTensor(fix_direction).to(device)
        else:
            # fixed gradient direction (calculated at clean inputs)
            assert args.num_fix_direction == len(args.ref_arch)

    for batch_index, data_tuple in enumerate(loader):
        # ImageNet loaders yield two resolutions; pick the one the model expects
        if args.dataset == "ImageNet":
            if target_model.input_size[-1] >= 299:
                image, label = data_tuple[1], data_tuple[2]
            else:
                image, label = data_tuple[0], data_tuple[2]
        else:
            image, label = data_tuple[0], data_tuple[1]

        if image.size(-1) != target_model.input_size[-1]:
            image = F.interpolate(image, size=target_model.input_size[-1], mode='bilinear')
        assert image.max().item() <= 1
        assert image.min().item() >= 0

        # move image and label to device
        image = image.to(device)
        label = label.to(device)
        adv_image = image.clone()

        # get logit and prob
        logit = target_model(image)
        adv_logit = logit.clone()
        pred = logit.argmax(dim=1)

        # choose target classes for targeted attack
        if args.targeted:
            if args.target_type == 'random':
                target = torch.randint(low=0, high=logit.shape[1], size=label.shape).long().to(device)
                # make sure target is not equal to label for any example
                invalid_target_index = target.eq(label)
                while invalid_target_index.sum().item() > 0:
                    target[invalid_target_index] = torch.randint(low=0, high=logit.shape[1],
                                                                 size=target[invalid_target_index].shape).long().to(
                        device)
                    invalid_target_index = target.eq(label)
            elif args.target_type == 'least_likely':
                target = logit.argmin(dim=1)
            elif args.target_type == "increment":
                target = torch.fmod(label+1, CLASS_NUM[args.dataset])
            else:
                raise NotImplementedError('Unknown target_type: {}'.format(args.target_type))

        else:
            target = None

        # init masks and selectors
        correct = pred.eq(label).float()
        query = torch.zeros(args.batch_size).to(device)
        not_done = correct.clone()
        # indices of this batch inside the *_all result arrays
        selected = torch.arange(batch_index * args.batch_size, (batch_index + 1) * args.batch_size)

        # init prior (running gradient estimate), in the subspace if delta_size > 0
        if args.delta_size > 0:
            prior = torch.zeros(args.batch_size, target_model.input_size[0], args.delta_size, args.delta_size).to(device)
        else:
            prior = torch.zeros(args.batch_size, *target_model.input_size).to(device)

        # perform attack; each step spends two queries (q1 and q2 below)
        for step_index in range(args.max_queries // 2):
            # increase query counts only for images whose attack is unfinished
            query = query + 2 * not_done

            # calculate drop probability for the ref models' dropout
            if step_index < args.ref_arch_drop_grow_iter:
                drop = args.ref_arch_init_drop
            else:
                # args.ref_arch_max_drop defaults to 0.5; drop grows toward it
                drop = args.ref_arch_max_drop - \
                    (args.ref_arch_max_drop - args.ref_arch_init_drop) * \
                    np.exp(-(step_index - args.ref_arch_drop_grow_iter) * args.ref_arch_drop_gamma)

            # finite difference for gradient estimation
            if len(ref_models) > 0:
                # select ref model to calculate gradient
                selected_ref_arch_index = torch.randint(low=0, high=len(ref_models), size=(1,)).long().item()

                # get original model logit's grad
                adv_logit = adv_logit.detach()
                adv_logit.requires_grad = True
                loss = loss_func(adv_logit, label, target).mean()
                logit_grad = torch.autograd.grad(loss, [adv_logit])[0]

                # calculate gradient for all ref target_models
                def calc_ref_grad(adv_image_, ref_model_, drop_=0):
                    adv_image_ = adv_image_.detach()
                    adv_image_.requires_grad = True
                    # FIX: explicit None check -- truth-testing a multi-element
                    # tensor raises a RuntimeError
                    if adv_image_.grad is not None:
                        adv_image_.grad[:] = 0.
                    ref_model_.zero_grad()

                    # assign dropout probability (see the ref model's code for
                    # how drop is applied)
                    ref_model_.drop = drop_

                    # forward ref model
                    # FIX: was `model.input_size`; `model` is undefined in this
                    # function -- the victim is `target_model`
                    if ref_model_.input_size != target_model.input_size:
                        ref_logit_ = ref_model_(F.interpolate(adv_image_, size=ref_model_.input_size[-1],
                                                              mode='bilinear', align_corners=True))
                    else:
                        ref_logit_ = ref_model_(adv_image_)

                    # backward ref model using logit_grad from the victim model
                    ref_grad_ = torch.autograd.grad(ref_logit_, [adv_image_], grad_outputs=[logit_grad])[0]
                    ref_grad_ = downsampler(ref_grad_)  # shrink back into the subspace

                    # compute dl/dv
                    if args.fix_grad:
                        if prior.view(prior.shape[0], -1).norm(dim=1).min().item() > 0:
                            # -1 / ||v|| ** 3 (||v|| ** 2 dL/dv - v(v^T dL/dv))
                            g1 = norm(prior) ** 2 * ref_grad_
                            g2 = prior * (prior * ref_grad_).sum(dim=(1, 2, 3)).view(-1, 1, 1, 1)
                            ref_grad_ = g1 - g2
                    return ref_grad_ / norm(ref_grad_)  # normalized search direction

                # calculate selected ref model's gradient
                if args.num_fix_direction == 0:
                    # randomly pick one ref model and take its gradient at
                    # adv_image as the search direction
                    direction = calc_ref_grad(adv_image, list(ref_models.values())[selected_ref_arch_index], drop_=drop)
                else:
                    # for illustrate experiment in rebuttal
                    assert args.loss == 'cw'
                    assert drop == 0
                    direction = calc_ref_grad(image, list(ref_models.values())[selected_ref_arch_index], drop_=drop)

            else:
                # use random search direction solely
                if args.num_fix_direction > 0:
                    # use fixed direction (for illustration experiments)
                    if len(args.ref_arch) == 0:
                        # fixed random direction
                        # fix_direction.shape: [num_fix_direction, C, H, W]
                        # coeff.shape: [num_fix_direction, N]
                        coeff = torch.randn(args.num_fix_direction, prior.shape[0]).to(device)
                        direction = (fix_direction.view(fix_direction.shape[0], 1, *fix_direction.shape[1:]) *
                                     coeff.view(coeff.shape[0], coeff.shape[1], 1, 1, 1)).sum(dim=0)
                    else:
                        # fixed gradient direction (calculated at clean inputs) for rebuttal
                        # direction has already been calculated
                        assert direction.shape[0] == image.shape[0]
                else:
                    direction = torch.randn_like(prior)

            # normalize search direction (estimated from a randomly chosen ref
            # model; the authors note it could be replaced by a meta-learned
            # direction, using the accumulated prior as ground truth)
            direction = direction / norm(direction)

            # probe the loss at +/- exploration along the direction
            # (up/down-samplers only matter for large ImageNet inputs)
            q1 = upsampler(prior + args.exploration * direction)
            q2 = upsampler(prior - args.exploration * direction)
            l1 = loss_func(target_model(adv_image + args.fd_eta * q1 / norm(q1)), label, target)  # query 1
            l2 = loss_func(target_model(adv_image + args.fd_eta * q2 / norm(q2)), label, target)  # query 2
            # finite-difference scalar (Delta_t in the bandits paper, Alg. 1 line 11)
            grad = (l1 - l2) / (args.fd_eta * args.exploration)
            # as in the bandits attack, but with exp_noise replaced by direction;
            # grad.shape == direction.shape == prior.shape
            grad = grad.view(-1, 1, 1, 1) * direction

            # update prior: the prior is the running gradient estimate, each
            # step refines the previous one (Bayesian-prior flavor)
            prior = prior_step(prior, grad, args.prior_lr)

            # extract grad from prior (the prior acts as the image gradient)
            grad = upsampler(prior)

            # update adv_image (correctly classified images only)
            adv_image = image_step(adv_image, grad * correct.view(-1, 1, 1, 1), args.image_lr)
            adv_image = proj_step(image, args.epsilon, adv_image)
            adv_image = torch.clamp(adv_image, 0, 1)

            # update statistics
            with torch.no_grad():
                adv_logit = target_model(adv_image)
            adv_pred = adv_logit.argmax(dim=1)
            adv_prob = F.softmax(adv_logit, dim=1)
            adv_loss = loss_func(adv_logit, label, target)
            if args.targeted:
                # FIX: ne() instead of `1 - eq()` -- identical semantics, but
                # int-minus-bool errors on modern PyTorch
                not_done = not_done * adv_pred.ne(target).float()
            else:
                not_done = not_done * adv_pred.eq(label).float()
            success = (1 - not_done) * correct  # currently done, originally correct --> success
            success_query = success * query
            not_done_loss = adv_loss * not_done
            not_done_prob = adv_prob[torch.arange(args.batch_size), label] * not_done
            # log
            log.info('Attacking image {} - {} / {}, step {}, max query {}'.format(
                batch_index * args.batch_size, (batch_index + 1) * args.batch_size,
                total_images, step_index + 1, int(query.max().item())
            ))
            log.info('        correct: {:.4f}'.format(correct.mean().item()))
            log.info('       not_done: {:.4f}'.format(not_done.mean().item()))
            log.info('      fd_scalar: {:.4f}'.format((l1 - l2).mean().item()))
            log.info('           drop: {:.4f}'.format(drop))
            if success.sum().item() > 0:
                log.info('     mean_query: {:.4f}'.format(success_query[success.byte()].mean().item()))
                log.info('   median_query: {:.4f}'.format(success_query[success.byte()].median().item()))
            if not_done.sum().item() > 0:
                log.info('  not_done_loss: {:.4f}'.format(not_done_loss[not_done.byte()].mean().item()))
                log.info('  not_done_prob: {:.4f}'.format(not_done_prob[not_done.byte()].mean().item()))

            # early break if all succeed
            if not not_done.byte().any():
                log.info('  image {} - {} all succeed in step {}, break'.format(
                    batch_index * args.batch_size, (batch_index + 1) * args.batch_size, step_index
                ))
                break

        # save results to arrays: copy this batch's final statistics into the
        # aggregate *_all tensors (e.g. success_query_all)
        for key in ['query', 'correct', 'not_done',
                    'success', 'success_query', 'not_done_loss', 'not_done_prob']:
            value_all = eval('{}_all'.format(key))
            value = eval(key)
            value_all[selected] = value.detach().float().cpu()

    # log statistics for $total_images images
    log.info('Attack finished ({} images)'.format(total_images))
    log.info('        avg correct: {:.4f}'.format(correct_all.mean().item()))
    log.info('       avg not_done: {:.4f}'.format(not_done_all.mean().item()))
    if success_all.sum().item() > 0:
        log.info('     avg mean_query: {:.4f}'.format(success_query_all[success_all.byte()].mean().item()))
        log.info('   avg median_query: {:.4f}'.format(success_query_all[success_all.byte()].median().item()))
    if not_done_all.sum().item() > 0:
        log.info('  avg not_done_loss: {:.4f}'.format(not_done_loss_all[not_done_all.byte()].mean().item()))
        log.info('  avg not_done_prob: {:.4f}'.format(not_done_prob_all[not_done_all.byte()].mean().item()))
    log.info('Saving results to {}'.format(result_dump_path))

    meta_info_dict = {"avg_correct": correct_all.mean().item(),
                      "avg_not_done": not_done_all[correct_all.byte()].mean().item(),
                      "mean_query": success_query_all[success_all.byte()].mean().item(),
                      "median_query": success_query_all[success_all.byte()].median().item(),
                      "max_query": success_query_all[success_all.byte()].max().item(),
                      "correct_all": correct_all.detach().cpu().numpy().astype(np.int32).tolist(),
                      "not_done_all": not_done_all.detach().cpu().numpy().astype(np.int32).tolist(),
                      "query_all": query_all.detach().cpu().numpy().astype(np.int32).tolist(),
                      "not_done_loss": not_done_loss_all[not_done_all.byte()].mean().item(),
                      "not_done_prob": not_done_prob_all[not_done_all.byte()].mean().item(),
                      "args":vars(args)}
    with open(result_dump_path, "w") as result_file_obj:
        json.dump(meta_info_dict, result_file_obj, sort_keys=True)
    # save results to disk
    log.info('Done')
Exemplo n.º 21
0
    def __init__(self,
                 dataset,
                 order,
                 r,
                 rho,
                 mom,
                 n_samples,
                 targeted,
                 target_type,
                 norm,
                 epsilon,
                 low_dim,
                 lower_bound=0.0,
                 upper_bound=1.0,
                 max_queries=10000):
        """
        Configure a black-box attack over an entire test set: validate the
        norm constraint, build the (batch-size-1) data loader, record image
        geometry for the dataset, and allocate per-image statistics buffers.

            :param dataset: dataset name (the code branches on "CIFAR*",
                "TinyImageNet", "ImageNet"); also keys into IMAGE_SIZE and
                IN_CHANNELS for image geometry
            :param order: stored verbatim on self.order; presumably selects the
                gradient-estimation order/strategy — TODO confirm against the
                attack loop
            :param r: stored verbatim on self.r; semantics not visible here —
                likely a sampling radius/step size, verify with callers
            :param rho: stored verbatim on self.rho; likely a smoothing /
                finite-difference parameter — TODO confirm
            :param mom: momentum coefficient (original note: default 1 means
                the momentum term is not added)
            :param n_samples: number of samples drawn per iteration (1 by
                default); NOT the number of images to be evaluated
            :param targeted: whether the attack is targeted
            :param target_type: how target labels are chosen when targeted
            :param epsilon: perturbation limit according to lp-ball
            :param norm: norm for the lp-ball constraint ('linf' or 'l2')
            :param low_dim: dimensionality of the low-dimensional search
                space used by the attack
            :param lower_bound: minimum value data point can take in any coordinate
            :param upper_bound: maximum value data point can take in any coordinate
            :param max_queries: max number of calls to model per data point
        """
        assert norm in ['linf', 'l2'], "{} is not supported".format(norm)
        self.epsilon = epsilon
        self.norm = norm
        self.max_queries = max_queries
        self.order = order
        self.r = r

        # Valid pixel-value range for clipping perturbed images.
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        # self.early_stop_crit_fct = lambda model, x, y: 1 - model(x).max(1)[1].eq(y)

        self.targeted = targeted
        self.target_type = target_type
        self.dataset = dataset
        # NOTE(review): batch size is hard-coded to 1 here, unlike the sibling
        # attack class which takes a batch_size argument — images are attacked
        # one at a time.
        self.data_loader = DataLoaderMaker.get_test_attacked_data(dataset, 1)
        self.total_images = len(self.data_loader.dataset)
        self.image_height = IMAGE_SIZE[dataset][0]
        self.image_width = IMAGE_SIZE[dataset][1]
        self.in_channels = IN_CHANNELS[dataset]
        self.low_dim = low_dim
        # Per-image statistics, indexed by position in the test set; filled in
        # as the attack runs and aggregated at the end.
        self.query_all = torch.zeros(self.total_images)
        self.correct_all = torch.zeros_like(self.query_all)  # number of images
        self.not_done_all = torch.zeros_like(
            self.query_all
        )  # always set to 0 if the original image is misclassified
        self.success_all = torch.zeros_like(self.query_all)
        self.success_query_all = torch.zeros_like(self.query_all)
        self.not_done_prob_all = torch.zeros_like(self.query_all)
        # Frequency-domain search geometry per dataset.
        # NOTE(review): datasets outside these three branches leave freq_dim /
        # stride undefined, which would raise AttributeError later — confirm
        # callers only pass the handled dataset names.
        if dataset.startswith("CIFAR"):  # TODO: experiment with different parameter settings
            self.freq_dim = 11  # 28
            self.stride = 7
        elif dataset == "TinyImageNet":
            self.freq_dim = 15
            self.stride = 7
        elif dataset == "ImageNet":
            self.freq_dim = 28
            self.stride = 7
        self.mom = mom  # default 1 not add
        self.n_samples = n_samples  # number of samples per iteration (1 by default), not the number of images to be evaluated.
        self.rho = rho
        self.construct_random_matrix()