    def attack(self, image, label):
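        """Run a block-wise local Bayesian search attack on a single image.

        Resets the query counter, initializes the noise, then alternates
        local Bayesian search over the blocks in the positive and negative
        directions, halving the block size between rounds. Returns the
        perturbed image and a success flag.
        """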
        self.function.new_counter()
        self.block_size = self.config['block_size']["{}x{}".format(
            self.image_height, self.image_width)]
        self.noise = self.noise_init()
        self.image = image.clone()
        self.label = label
        _, self.loss = self.function(perturb_image(image, self.noise), label)
        if self.loss < 0:
            image = perturb_image(self.image, self.noise)
            return image, True
        upper_left = [0, 0]
        lower_right = [self.image_height, self.image_width]
        blocks = self.split_block(self.image, upper_left, lower_right,
                                  self.block_size)

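        # Coarse-to-fine search: run local Bayesian search at the current
        # block size, then halve the block size and repeat until the query
        # budget is exhausted or an adversarial example is found.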
        while True:
            # Run local search algorithm on the mini-batch
            # Scale factors that map block coordinates into the GP's
            # normalized search space.
            self.gp_normalize = torch.tensor(
                [
                    self.image_height / self.block_size,
                    self.image_width / self.block_size, self.channels, 1
                ],
                dtype=torch.float32,
                device=self.device)
            for _ in range(self.max_iters):
                success = self.local_bayes(blocks, "positive")
                if success or self.function.current_counts > self.query_limit:
                    image = perturb_image(self.image, self.noise)
                    return image, success

                success = self.local_bayes(blocks, "negative")
                if success or self.function.current_counts > self.query_limit:
                    image = perturb_image(self.image, self.noise)
                    return image, success

                if self.config['print_log']:
                    log.info(
                        "Block size: {}, loss: {:.4f}, num queries: {}".format(
                            self.block_size, self.loss.item(),
                            self.function.current_counts))

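            # Refine the grid: halve the block size, adjusting odd sizes to a
            # divisor of the image height so the blocks still tile the image.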
            if self.block_size >= 2:
                if self.block_size % 2 != 0:
                    temp_block_size = self.block_size // 2
                    if temp_block_size < 10:
                        for t in range(1, 10):
                            if self.image_height % t == 0:
                                self.block_size = t
                    else:
                        while self.image_height % temp_block_size != 0:
                            temp_block_size += 1
                        self.block_size = temp_block_size
                else:
                    self.block_size //= 2
                blocks = self.split_block(self.image, upper_left, lower_right,
                                          self.block_size)
            if self.function.current_counts > self.maximum_queries:
                image = perturb_image(self.image, self.noise)
                return image, False

    def get_loss(self, indices):
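        """Batched evaluation of candidate blocks.

        For each index, `change_noise` is applied to the corresponding
        block with +sigma and with -sigma, the smaller of the two resulting
        losses is kept, and the difference to the current loss is returned.
        """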
        indices = indices * self.gp_normalize
        batch_size = self.batch_size
        num_batches = int(math.ceil(len(indices) / batch_size))
        losses = torch.zeros(len(indices), device=self.device)
        for ibatch in range(num_batches):
            bstart = ibatch * batch_size
            bend = min(bstart + batch_size, len(indices))
            images = self.image.unsqueeze(0).repeat(bend - bstart, 1, 1, 1)

            # Query each candidate block with a positive perturbation ...
            for i, index in enumerate(indices[bstart:bend]):
                noise_flip = change_noise(self.noise, index, self.block_size,
                                          self.sigma, self.epsilon)
                images[i] = perturb_image(self.image, noise_flip)
            _, loss_p = self.function(images, self.label)
            # ... and with a negative perturbation of the same block.
            for i, index in enumerate(indices[bstart:bend]):
                noise_flip = change_noise(self.noise, index, self.block_size,
                                          -self.sigma, self.epsilon)
                images[i] = perturb_image(self.image, noise_flip)
            _, loss_n = self.function(images, self.label)
            # Keep the better (smaller) loss of the two directions.
            losses[bstart:bend] = torch.min(loss_n, loss_p)

        return losses - self.loss

    def local_bayes(self, blocks, direction):
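        """One round of local Bayesian search over `blocks` in one direction.

        Only blocks whose current noise sign opposes `direction` are
        candidates. A GP surrogate proposes one block per step; flips that
        improve the loss are accepted and nearby GP observations forgotten,
        while non-improving flips are stored in a bounded, priority-ordered
        GP memory. Returns True as soon as the loss drops below zero.
        """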
        select_blocks = []
        for i, block in enumerate(blocks):
            x, y, c = block[0:3]
            x *= self.block_size
            y *= self.block_size
            if direction == "positive" and self.noise[
                    c, x, y] < 0 or direction == "negative" and self.noise[
                        c, x, y] > 0:
                select_blocks.append(block)

        blocks = torch.tensor(select_blocks,
                              dtype=torch.float32,
                              device=self.device)
        if blocks.size(0) < 2:
            return False
        init_batch_size = max(blocks.size(0) // self.init_batch, 5)
        init_iteration = self.init_iter
        if init_batch_size * init_iteration > blocks.size(0):
            if blocks.size(0) // init_iteration < 2:
                init_iteration = blocks.size(0) // 2
                init_batch_size = 2
            else:
                init_batch_size = blocks.size(0) // init_iteration
        init_iteration = init_batch_size * (init_iteration - 1)
        self.gp.init(blocks / self.gp_normalize,
                     n_init=init_batch_size,
                     batch_size=1,
                     iteration=init_iteration)

        self.gp.X_pool = blocks / self.gp_normalize

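        # Bounded GP memory: priority_X records the insertion order of every
        # observation so that the oldest entries can be evicted first.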
        memory_size = int(len(self.gp.X) * self.memory_size)
        priority_X = torch.arange(0, len(self.gp.X)).to(self.gp.X.device)
        priority = torch.tensor(self.gp.X.size(0)).to(priority_X.device)

        local_forget_threshold = self.local_forget_threshold[self.block_size]
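        # Main search loop: the GP proposes one block per step, the proposal
        # is evaluated with a real query, and the GP memory is updated.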
        for i in range(blocks.size(0)):
            training_steps = 1
            x_cand, y_cand, self.gp.hypers = self.gp.create_candidates(
                self.gp.X,
                self.gp.fX,
                self.gp.X_pool,
                n_training_steps=training_steps,
                hypers=self.gp.hypers,
                sample_number=1)
            block, self.gp.X_pool = self.gp.select_candidates(x_cand,
                                                              y_cand,
                                                              get_loss=False)
            block = block[0] * self.gp_normalize
            # Stop this round early if, after trying half of the blocks, the
            # GP candidates no longer predict a meaningful loss decrease.
            if i >= blocks.size(0) // 2 and y_cand.min() > -1e-4:
                return False

            noise = flip_noise(self.noise, block, self.block_size)
            query_image = perturb_image(self.image, noise)
            logit, loss = self.function(query_image, self.label)

            if loss < 0:
                self.loss = loss
                return True

            if self.function.current_counts > self.query_limit:
                return False

            if self.config['print_log']:
                log.info(
                    "queries {}, new loss {:4f}, old loss {:4f}, gaussian size {}"
                    .format(self.function.current_counts, loss.item(),
                            self.loss.item(), len(self.gp.X)))

            if loss < self.loss:
                # The flip improves the loss: accept it, then forget GP
                # observations lying close to the accepted block, since the
                # objective has changed in that neighbourhood.
                self.noise = noise.clone()
                self.loss = loss

                diff = (self.gp.X * self.gp_normalize -
                        block)[:, 0:2].abs().max(dim=1)[0]
                keep = diff > (local_forget_threshold + 0.5)
                self.gp.X = self.gp.X[keep]
                self.gp.fX = self.gp.fX[keep]
                priority_X = priority_X[keep]

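                # When the memory budget is reached, drop the oldest entry.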
                if len(priority_X) >= memory_size:
                    index = torch.argmin(priority_X)
                    priority_X = torch.cat(
                        (priority_X[:index], priority_X[index + 1:]))
                    self.gp.X = torch.cat(
                        (self.gp.X[:index], self.gp.X[index + 1:]), dim=0)
                    self.gp.fX = torch.cat(
                        (self.gp.fX[:index], self.gp.fX[index + 1:]), dim=0)

                if len(self.gp.X_pool) == 0:
                    break

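                # If forgetting left the GP with at most one observation,
                # reseed it with a randomly chosen block from the pool.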
                if self.gp.X.size(0) <= 1:
                    new_index = random.randint(0, len(self.gp.X_pool) - 1)
                    new_block = self.gp.X_pool[new_index] * self.gp_normalize

                    query_image = perturb_image(
                        self.image,
                        flip_noise(self.noise, new_block, self.block_size))
                    _, query_loss = self.function(query_image, self.label)

                    self.gp.X = torch.cat(
                        (self.gp.X,
                         (new_block / self.gp_normalize).unsqueeze(0)),
                        dim=0)
                    self.gp.fX = torch.cat(
                        (self.gp.fX, query_loss - self.loss), dim=0)

                    priority_X = torch.cat((priority_X, priority.unsqueeze(0)),
                                           dim=0)
                    priority += 1
            else:
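                # No improvement: record the observation in the GP memory by
                # updating an existing entry, appending a new one, or
                # overwriting the oldest entry when the memory is full.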
                diff = (self.gp.X - block / self.gp_normalize).abs().sum(dim=1)
                min_diff, history_index = torch.min(diff, dim=0)
                if min_diff < 1e-5:
                    update_index = history_index
                elif len(priority_X) < memory_size:
                    update_index = len(priority_X)
                    self.gp.X = torch.cat((self.gp.X, self.gp_emptyX), dim=0)
                    self.gp.fX = torch.cat((self.gp.fX, self.gp_emptyfX),
                                           dim=0)
                    priority_X = torch.cat((priority_X, priority.unsqueeze(0)),
                                           dim=0)
                else:
                    update_index = torch.argmin(priority_X)

                self.gp.X[update_index] = block / self.gp_normalize
                self.gp.fX[update_index] = loss - self.loss
                priority_X[update_index] = priority
                priority += 1
            if self.function.current_counts > self.maximum_queries:
                return False
        return False