Example 1
    def forward(self, x):
        # Note: we do not use actual masking, because that would
        # require boolean indexing, which triggers a CUDA
        # synchronization (a major slowdown)

        if self.training or self.always_linear.requires_grad or self.always_zero.requires_grad:
            raise NotImplementedError(
                'MaskedReLU is not designed for training.')

        output = torch.relu(x)

        # Add a batch dimension and expand it to match the batch size of x
        expand_size = [len(x)] + [-1] * len(x.shape[1:])
        always_zero = self.always_zero.unsqueeze(0).expand(*expand_size)
        always_linear = self.always_linear.unsqueeze(0).expand(*expand_size)

        # always_zero masking
        output = utils.fast_boolean_choice(output,
                                           torch.zeros_like(x),
                                           always_zero,
                                           reshape=False)

        # always_linear masking
        output = utils.fast_boolean_choice(output,
                                           x,
                                           always_linear,
                                           reshape=False)

        return output
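
utils.fast_boolean_choice is not shown in these examples. A minimal sketch of what such a synchronization-free selection helper might look like, assuming torch.where semantics, with reshape interpreted as broadcasting a per-sample (batch,) mask over the remaining dimensions:

import torch

def fast_boolean_choice(false_values, true_values, mask, reshape=True):
    # torch.where selects elementwise on the device, avoiding the
    # boolean indexing (and CUDA sync) mentioned in the comment above
    if reshape:
        # Broadcast a (batch,) mask across the remaining dimensions
        mask = mask.reshape([len(mask)] + [1] * (false_values.dim() - 1))
    return torch.where(mask, true_values, false_values)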
Example 2
    def perturb(self, x, y=None):
        x, y = self._verify_and_process_inputs(x, y)

        # Initialization
        if y is None:
            y = self._get_predicted_label(x)

        batch_size = len(x)

        with torch.no_grad():
            best_adversarials = x.clone()
            best_distances = torch.ones(
                (batch_size, ), device=x.device) * MAX_DISTANCE

            for _ in range(self.count):
                noise = torch.rand(x.shape, device=x.device)

                scaled_noise = (noise * 2 - 1) * self.eps
                adversarials = torch.clamp(x + scaled_noise,
                                           min=self.clip_min,
                                           max=self.clip_max)

                assert adversarials.shape == x.shape

                successful = self.successful(adversarials, y).detach()

                distances = utils.adversarial_distance(x, adversarials, self.p)
                better_distance = distances < best_distances

                assert len(better_distance) == len(best_adversarials) == len(
                    best_distances)

                best_adversarials = utils.fast_boolean_choice(
                    best_adversarials, adversarials,
                    successful & better_distance)
                best_distances = utils.fast_boolean_choice(
                    best_distances, distances, successful & better_distance)

            return best_adversarials
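
utils.adversarial_distance is also undefined here. A plausible sketch, assuming it returns one Lp distance per batch element:

import torch

def adversarial_distance(genuines, adversarials, p):
    # Per-sample Lp distance, computed over flattened samples
    return torch.norm((genuines - adversarials).flatten(1), p=p, dim=1)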
Example 3
    def perturb(self, x, y=None, **kwargs):
        x, y = self._verify_and_process_inputs(x, y)

        # Initialization
        if y is None:
            y = self._get_predicted_label(x)

        # Create a tracker that will be only used for this batch
        tracker = BestSampleTracker(x, y, self.p, self.targeted)
        self.wrapped_model.tracker = tracker

        last_adversarials = self.inner_attack(x, y=y, **kwargs)

        # In case the last adversarials were not tested
        self.wrapped_model(last_adversarials)

        # Where the tracker never found an adversarial, fall back to
        # the last outputs of the inner attack
        final_adversarials = utils.fast_boolean_choice(
            last_adversarials, tracker.best, tracker.found_adversarial)

        self.wrapped_model.tracker = None

        return final_adversarials
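
BestSampleTracker is only constructed here. A hypothetical reconstruction of its state, inferred from the attributes read and written in Examples 3 and 4:

import torch

class BestSampleTracker:
    # Hypothetical sketch: tracks, per genuine sample, the closest
    # successful adversarial found so far
    def __init__(self, genuines, labels, p, targeted):
        batch_size = len(genuines)
        self.genuines = genuines.detach().clone()
        self.labels = labels
        self.p = p
        self.targeted = targeted
        self.best = genuines.detach().clone()
        self.best_distances = torch.full((batch_size,), float('inf'),
                                         device=genuines.device)
        self.found_adversarial = torch.zeros((batch_size,),
                                             dtype=torch.bool,
                                             device=genuines.device)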
Example 4
    def forward(self, x, active_mask=None, filter_=None):
        if self.tracker is None:
            raise RuntimeError('No best sample tracker set.')

        # Don't detach here: attacks might require the gradients
        outputs = self.model(x)

        relevant_labels = self.tracker.labels
        relevant_genuines = self.tracker.genuines
        relevant_best_distances = self.tracker.best_distances
        relevant_found_adversarial = self.tracker.found_adversarial

        if active_mask is None:
            assert len(x) == len(self.tracker.genuines)
        else:
            assert len(x) == torch.count_nonzero(active_mask)
            # Boolean indexing causes a CUDA sync, which is why we do it only
            # if absolutely necessary
            relevant_labels = relevant_labels[active_mask]
            relevant_genuines = relevant_genuines[active_mask]
            relevant_best_distances = relevant_best_distances[active_mask]
            relevant_found_adversarial = relevant_found_adversarial[
                active_mask]

            # x doesn't need to be masked, since len(x) == torch.count_nonzero(active_mask)

        with torch.no_grad():
            adversarial_labels = torch.argmax(outputs, dim=1)
            if self.tracker.targeted:
                successful = torch.eq(adversarial_labels, relevant_labels)
            else:
                successful = ~torch.eq(adversarial_labels, relevant_labels)

            distances = utils.adversarial_distance(relevant_genuines, x,
                                                   self.tracker.p)
            better_distance = distances < relevant_best_distances

            # Replace only if successful and with a better distance
            replace = successful & (better_distance |
                                    (~relevant_found_adversarial))

            # filter_ restricts updates to only some samples
            if filter_ is not None:
                replace &= filter_

            new_found_adversarial = relevant_found_adversarial | replace

            if active_mask is None:
                self.tracker.best = utils.fast_boolean_choice(
                    self.tracker.best, x, replace)
                self.tracker.best_distances = utils.fast_boolean_choice(
                    self.tracker.best_distances, distances, replace)
                self.tracker.found_adversarial = new_found_adversarial
            else:
                # A masked replacement requires a different function
                utils.replace_active(x, self.tracker.best, active_mask,
                                     replace)
                utils.replace_active(distances, self.tracker.best_distances,
                                     active_mask, replace)
                self.tracker.found_adversarial[
                    active_mask] = new_found_adversarial

        return outputs
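
utils.replace_active performs the masked scatter used on the active_mask path. A sketch of one possible implementation (unlike fast_boolean_choice, this one does incur a CUDA sync, which matches the comments above about accepting the sync only when necessary):

import torch

def replace_active(source, destination, active_mask, replace):
    # `source` holds one row per *active* sample; scatter its rows back
    # into the full-size `destination` wherever the sample is active
    # and flagged for replacement. torch.nonzero and boolean indexing
    # force a CUDA synchronization
    active_indices = torch.nonzero(active_mask, as_tuple=True)[0]
    destination[active_indices[replace]] = source[replace]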
Example 5
    def perturb(self, x, y=None):
        x, y = self._verify_and_process_inputs(x, y)

        # Initialization
        if y is None:
            y = self._get_predicted_label(x)

        best_adversarials = x.clone()
        batch_size = len(x)

        eps_lower_bound = torch.ones(
            (batch_size, ), device=x.device) * self.min_eps
        eps_upper_bound = torch.ones(
            (batch_size, ), device=x.device) * self.max_eps
        best_distances = torch.ones(
            (batch_size, ), device=x.device) * MAX_DISTANCE

        initial_search_eps = eps_upper_bound.clone()
        for _ in range(self.eps_initial_search_steps):
            adversarials = self.perturb_standard(x, y,
                                                 initial_search_eps).detach()

            assert adversarials.shape == x.shape

            successful = self.successful(adversarials, y).detach()

            distances = utils.adversarial_distance(x, adversarials, self.p)
            better_distance = distances < best_distances

            replace = successful & better_distance

            best_adversarials = utils.fast_boolean_choice(
                best_adversarials, adversarials, replace)
            best_distances = utils.fast_boolean_choice(best_distances,
                                                       distances, replace)

            # Success: Reduce the upper bound
            eps_upper_bound = utils.fast_boolean_choice(
                eps_upper_bound, initial_search_eps, successful)

            # Reduce eps, regardless of success
            initial_search_eps = initial_search_eps * self.eps_initial_search_factor

        for _ in range(self.eps_binary_search_steps):
            eps = (eps_lower_bound + eps_upper_bound) / 2
            adversarials = self.perturb_standard(x, y, eps).detach()
            successful = self.successful(adversarials, y).detach()

            distances = utils.adversarial_distance(x, adversarials, self.p)
            better_distance = distances < best_distances
            replace = successful & better_distance

            best_adversarials = utils.fast_boolean_choice(
                best_adversarials, adversarials, replace)
            best_distances = utils.fast_boolean_choice(best_distances,
                                                       distances, replace)

            # Success: Reduce the upper bound
            eps_upper_bound = utils.fast_boolean_choice(
                eps_upper_bound, eps, successful)

            # Failure: Increase the lower bound
            eps_lower_bound = utils.fast_boolean_choice(
                eps_lower_bound, eps, ~successful)

        assert best_adversarials.shape == x.shape

        return best_adversarials
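
perturb_standard, which runs the underlying attack with a per-sample eps, is not shown above. Purely for illustration, here is a free-function stand-in that assumes a single FGSM step (predict, clip_min and clip_max mirror the attributes used in the examples):

import torch

def perturb_standard(predict, x, y, eps, clip_min=0., clip_max=1.):
    # Illustrative stand-in for the real inner attack: one FGSM step
    # with a per-sample eps, broadcast across the sample dimensions
    x = x.clone().detach().requires_grad_(True)
    loss = torch.nn.functional.cross_entropy(predict(x), y)
    grad = torch.autograd.grad(loss, x)[0]
    eps = eps.reshape([len(eps)] + [1] * (x.dim() - 1))
    return torch.clamp(x + eps * grad.sign(), min=clip_min, max=clip_max)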
Example 6
    def perturb(self, x, y=None):
        x, y = self._verify_and_process_inputs(x, y)

        # Initialization
        if y is None:
            y = self._get_predicted_label(x)

        x = replicate_input(x)
        batch_size = len(x)
        final_adversarials = x.clone()

        # An array of booleans that stores which samples have not converged
        # yet
        active = torch.ones((batch_size, ), dtype=torch.bool, device=x.device)

        initial_const = self.initial_const
        taus = torch.ones((batch_size, ), device=x.device) * self.initial_tau

        # The previous adversarials. This is used to perform a "warm start"
        # during optimisation
        prev_adversarials = x.clone()

        max_tau = self.initial_tau

        i = 0

        while max_tau > self.min_tau:
            new_adversarials = self._run_attack(x, y, initial_const, taus,
                                                prev_adversarials.clone(),
                                                active).detach()

            # Store the adversarials for the next iteration,
            # even if they failed
            prev_adversarials = new_adversarials

            adversarial_outputs = self._outputs(new_adversarials,
                                                filter_=active)
            successful = self._successful(adversarial_outputs, y).detach()

            # If the Linf distance is lower than tau and the adversarial
            # is successful, use it as the new tau
            linf_distances = torch.max(torch.abs(new_adversarials -
                                                 x).flatten(1),
                                       dim=1)[0]
            linf_lower = linf_distances < taus

            taus = utils.fast_boolean_choice(taus, linf_distances,
                                             linf_lower & successful)

            # Save the remaining adversarials
            replace = successful

            if not self.update_inactive:
                replace = replace & active

            final_adversarials = utils.fast_boolean_choice(
                final_adversarials, new_adversarials, replace)

            taus *= self.tau_factor
            max_tau = taus.max()

            if self.reduce_const:
                initial_const /= 2

            # Drop samples that failed or whose tau is low
            low_tau = taus <= self.min_tau
            drop = low_tau | (~successful)
            active = active & (~drop)

            if self.tau_check != 0 and (i + 1) % self.tau_check == 0:
                # Causes an implicit sync point
                if not active.any():
                    break

            i += 1

        return final_adversarials
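
self._successful is not defined in these snippets. A sketch consistent with the success check in Example 4 (argmax equality against the target for targeted attacks, inequality against the true label otherwise), written as a free function:

import torch

def successful(outputs, y, targeted):
    # Mirrors Example 4: success means the predicted label matches the
    # target (targeted) or differs from the true label (untargeted)
    predicted = torch.argmax(outputs, dim=1)
    if targeted:
        return torch.eq(predicted, y)
    return ~torch.eq(predicted, y)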
Example 7
    def _run_attack(self, x, y, initial_const, taus, prev_adversarials,
                    active):
        assert len(x) == len(taus)
        batch_size = len(x)
        computed_adversarials = x.clone().detach()

        if self.warm_start:
            starting_atanh = self._get_arctanh_x(prev_adversarials.clone())
        else:
            starting_atanh = self._get_arctanh_x(x.clone())

        modifiers = torch.nn.Parameter(torch.zeros_like(starting_atanh))

        # Optimise the modifiers with Adam
        optimizer = optim.Adam([modifiers], lr=self.learning_rate)

        const = initial_const

        j = 0
        stop_search = False

        while (not stop_search) and const < self.max_const:
            # We add an extra iteration because adversarials are
            # not saved until the next iteration
            for k in range(self.max_iterations + 1):
                # Note: unlike the CPU version, the CUDA version updates and
                # calls the model on all samples, including inactive ones.
                # The filter_ parameter, however, restricts the best-sample
                # tracker to updating only the active samples. This is
                # wasteful, but necessary to keep the CPU and CUDA
                # implementations consistent
                outputs, losses = self._outputs_and_loss(x,
                                                         modifiers,
                                                         starting_atanh,
                                                         y,
                                                         const,
                                                         taus,
                                                         filter_=active)

                adversarials = tanh_rescale(starting_atanh + modifiers,
                                            self.clip_min,
                                            self.clip_max).detach()

                replace = torch.ones((batch_size, ),
                                     dtype=torch.bool,
                                     device=x.device)

                if not self.update_inactive:
                    replace = replace & active

                computed_adversarials = utils.fast_boolean_choice(
                    computed_adversarials, adversarials, replace)

                # Update the modifiers
                total_loss = torch.sum(losses)
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                # If early aborting is enabled, drop successful
                # samples with a small loss (the current adversarials
                # are saved regardless of whether they are dropped)
                if self.abort_early:
                    successful = self._successful(outputs, y).detach()
                    small_loss = losses < SMALL_LOSS_COEFFICIENT * const

                    active = active & ~(successful & small_loss)

                    if (self.inner_check != 0
                            and (k + 1) % self.inner_check == 0):
                        # Causes an implicit sync point
                        if not active.any():
                            # Break from both loops
                            stop_search = True
                            break

            if stop_search:
                break

            if (self.abort_early and self.const_check != 0
                    and (j + 1) % self.const_check == 0):
                # Causes an implicit sync point
                if not active.any():
                    break

            # Give more weight to the output loss
            const *= self.const_factor

        return computed_adversarials
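
Example 7 optimises in arctanh space, the Carlini-Wagner change of variables that keeps adversarials inside [clip_min, clip_max] by construction. tanh_rescale appears to be advertorch's helper; _get_arctanh_x is reconstructed below as a free-function sketch under that assumption:

import torch

def tanh_rescale(x, x_min=-1., x_max=1.):
    # Map an unconstrained tensor into [x_min, x_max] via tanh
    return (torch.tanh(x) + 1) / 2 * (x_max - x_min) + x_min

def get_arctanh_x(x, clip_min, clip_max):
    # Hypothetical inverse: rescale x into (-1, 1), shrink slightly so
    # that atanh stays finite, then apply atanh
    result = torch.clamp((x - clip_min) / (clip_max - clip_min),
                         min=0., max=1.) * 2 - 1
    return torch.atanh(result * (1 - 1e-6))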