def _parse_and_update_results(
    self, batch_summarizers: Dict[str, Summarizer]
) -> Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]]:
    """Format the per-batch summaries and fold them into ``self.summary_results``.

    Args:
        batch_summarizers: Summarizers for the current batch, keyed by
            ``ORIGINAL_KEY`` and by each attack's ``name``.

    Returns:
        A dict mapping ``ORIGINAL_KEY`` to the ``"mean"`` entry of the
        formatted original summary, and each attack name to its full
        formatted summary.
    """

    def _accumulate(summary_key, metric_value):
        # Lazily create the running summarizer the first time a key shows up,
        # then push the aggregated metric value into it.
        if summary_key not in self.summary_results:
            self.summary_results[summary_key] = Summarizer(
                [stat() for stat in self.aggregate_stats]
            )
        self.summary_results[summary_key].update(
            self.metric_aggregator(metric_value)
        )

    original_summary = self._format_summary(
        cast(Union[Dict, List], batch_summarizers[ORIGINAL_KEY].summary)
    )
    results: Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]] = {
        ORIGINAL_KEY: original_summary["mean"]
    }
    # The entry for ORIGINAL_KEY is not created here; it must already exist.
    self.summary_results[ORIGINAL_KEY].update(
        self.metric_aggregator(results[ORIGINAL_KEY])
    )

    for attack_key in self.attacks:
        attack = self.attacks[attack_key]
        formatted = self._format_summary(
            cast(Union[Dict, List], batch_summarizers[attack.name].summary)
        )
        results[attack.name] = formatted

        if len(formatted) == 1:
            # A single statistic is tracked under the plain attack name.
            sole_key = next(iter(formatted))
            _accumulate(attack.name, formatted[sole_key])
            continue

        # Several statistics: each one gets its own titled summary entry.
        for stat_key in formatted:
            _accumulate(
                f"{attack.name} {stat_key.title()} Attempt", formatted[stat_key]
            )

    return results
def test_single_input(self):
    """Check the summary produced for a single (non-tuple) tensor input.

    Ten random ``(2, 3)`` tensors are fed to a ``Summarizer`` built from
    ``CommonStats()``; the resulting summary must be a dict whose every
    entry has the same size as the inputs.
    """
    size = (2, 3)
    summarizer = Summarizer(stats=CommonStats())
    for _ in range(10):
        summarizer.update(torch.randn(size))

    summ = summarizer.summary
    self.assertIsNotNone(summ)
    # assertIsInstance / assertEqual include the offending values in the
    # failure message, which assertTrue(<bool expression>) cannot do.
    self.assertIsInstance(summ, dict)
    for k in summ:
        self.assertEqual(summ[k].size(), size)
def test_stats_random_data(self):
    """Compare each statistic against a numpy reference on ``get_values`` data.

    Every statistic is accumulated one scalar at a time through its own
    ``Summarizer`` and the final summary entry is compared with the value
    numpy computes over the full list of samples.
    """
    n_samples = 1000
    bound = 100000
    samples = list(get_values(lo=-bound, hi=bound, n=n_samples))

    def sample_variance(xs):
        return np.var(xs, ddof=1)

    def sample_std_dev(xs):
        return np.std(xs, ddof=1)

    def sum_squared_deviation(xs):
        return np.sum((xs - np.mean(xs)) ** 2)

    # (statistic under test, key in the summary dict, numpy reference)
    cases = [
        (Mean(), "mean", np.mean),
        (Var(), "variance", np.var),
        (Var(order=1), "sample_variance", sample_variance),
        (StdDev(), "std_dev", np.std),
        (StdDev(order=1), "sample_std_dev", sample_std_dev),
        (Min(), "min", np.min),
        (Max(), "max", np.max),
        (Sum(), "sum", np.sum),
        (MSE(), "mse", sum_squared_deviation),
    ]

    for stat, summary_key, reference_fn in cases:
        summarizer = Summarizer([stat])
        for sample in samples:
            summarizer.update(torch.tensor(sample, dtype=torch.float64))

        expected = torch.from_numpy(np.array(reference_fn(samples)))
        observed = summarizer.summary[summary_key]

        # Accumulated rounding error is significant (most of all for MSE),
        # hence the explicit tolerance.
        assertTensorAlmostEqual(self, observed, expected, delta=0.005)
def test_stats_random_data(self):
    """Compare each statistic against a torch reference on ``get_values`` data.

    The samples are packed into one float64 tensor. Each statistic is
    accumulated element by element through its own ``Summarizer`` and the
    final summary entry is compared with the value torch computes over the
    whole tensor.
    """
    n_samples = 1000
    bound = 100000
    raw_samples = list(get_values(lo=-bound, hi=bound, n=n_samples))
    samples = torch.tensor(raw_samples, dtype=torch.float64)

    def population_variance(t):
        return torch.var(t, unbiased=False)

    def sample_variance(t):
        return torch.var(t, unbiased=True)

    def population_std_dev(t):
        return torch.std(t, unbiased=False)

    def sample_std_dev(t):
        return torch.std(t, unbiased=True)

    def sum_squared_deviation(t):
        return torch.sum((t - torch.mean(t)) ** 2)

    # (statistic under test, key in the summary dict, torch reference)
    cases = [
        (Mean(), "mean", torch.mean),
        (Var(), "variance", population_variance),
        (Var(order=1), "sample_variance", sample_variance),
        (StdDev(), "std_dev", population_std_dev),
        (StdDev(order=1), "sample_std_dev", sample_std_dev),
        (Min(), "min", torch.min),
        (Max(), "max", torch.max),
        (Sum(), "sum", torch.sum),
        (MSE(), "mse", sum_squared_deviation),
    ]

    for stat, summary_key, reference_fn in cases:
        summarizer = Summarizer([stat])
        expected = reference_fn(samples)
        for sample in samples:
            summarizer.update(sample)
        observed = summarizer.summary[summary_key]

        # Accumulated rounding error is significant (most of all for MSE),
        # hence the explicit tolerance.
        assertTensorAlmostEqual(self, observed, expected, delta=0.005)
def test_multi_input(self):
    """Check the summary produced when each update is a tuple of two tensors.

    Ten pairs of random tensors (sizes ``(10, 5, 5)`` and ``(3, 5)``) are
    fed to a ``Summarizer`` built from ``CommonStats()``; the summary must
    hold one dict per tuple position, each entry sized like its input.
    """
    size1 = (10, 5, 5)
    size2 = (3, 5)
    summarizer = Summarizer(stats=CommonStats())
    for _ in range(10):
        summarizer.update((torch.randn(size1), torch.randn(size2)))

    summ = summarizer.summary
    self.assertIsNotNone(summ)
    # The dedicated assertions report the actual values on failure, which
    # assertTrue(<bool expression>) cannot do.
    self.assertEqual(len(summ), 2)
    self.assertIsInstance(summ[0], dict)
    self.assertIsInstance(summ[1], dict)
    for k in summ[0]:
        self.assertEqual(summ[0][k].size(), size1)
        self.assertEqual(summ[1][k].size(), size2)
def test_multi_dim(self):
    """Track element-wise mean and variance across three 4-element updates.

    After each update the ``"mean"`` and ``"variance"`` summary entries are
    compared with hand-computed expectations.
    """
    x1 = torch.tensor([1.0, 2.0, 3.0, 4.0])
    x2 = torch.tensor([2.0, 1.0, 2.0, 4.0])
    x3 = torch.tensor([3.0, 3.0, 1.0, 4.0])

    summarizer = Summarizer([Mean(), Var()])

    # (tensor to feed, expected running mean, expected running variance)
    steps = [
        (x1, x1, torch.zeros_like(x1)),
        (
            x2,
            torch.tensor([1.5, 1.5, 2.5, 4]),
            torch.tensor([0.25, 0.25, 0.25, 0]),
        ),
        (
            x3,
            torch.tensor([2, 2, 2, 4]),
            torch.tensor([2.0 / 3.0, 2.0 / 3.0, 2.0 / 3.0, 0]),
        ),
    ]

    for new_input, expected_mean, expected_variance in steps:
        summarizer.update(new_input)
        assertTensorAlmostEqual(
            self,
            summarizer.summary["mean"],
            expected_mean,
            delta=0.05,
            mode="max",
        )
        assertTensorAlmostEqual(
            self,
            summarizer.summary["variance"],
            expected_variance,
            delta=0.05,
            mode="max",
        )
def test_var_defin(self):
    """
    Variance is the average squared distance to the mean, so it should be
    positive. To check this, a heavily skewed two-valued distribution is
    built (mostly very small values, then mostly very large values) and
    the accumulated variance is compared with numpy's result, which is
    taken as the reference.
    """
    low_value = -10000
    high_value = 10000
    # (number of low values, number of high values) for each skewed scenario
    skews = ((100, 10), (10, 100))

    for n_low, n_high in skews:
        summarizer = Summarizer([Var()])
        samples = [low_value] * n_low + [high_value] * n_high
        for sample in samples:
            summarizer.update(torch.tensor(sample, dtype=torch.float64))

        expected = torch.from_numpy(np.array(np.var(samples)))
        observed = summarizer.summary["variance"]

        assertTensorAlmostEqual(self, observed, expected)
        self.assertTrue((observed > 0).all())
def test_div0(self):
    """Check summaries before any update and after one and two equal updates.

    The summary must be ``None`` before any data is seen; after feeding the
    value 10 once, and again a second time, the mean must be 10 and the
    variance 0.
    """
    summarizer = Summarizer([Var(), Mean()])
    self.assertIsNone(summarizer.summary)

    # Same expectations after the first and after the second identical sample.
    for _ in range(2):
        summarizer.update(torch.tensor(10))
        snapshot = summarizer.summary
        assertTensorAlmostEqual(self, snapshot["mean"], 10)
        assertTensorAlmostEqual(self, snapshot["variance"], 0)
def test_multi_dim(self):
    """Track element-wise mean and variance across three 4-element updates.

    After each update the ``"mean"`` and ``"variance"`` summary entries are
    compared with hand-computed expectations.
    """
    x1 = torch.tensor([1.0, 2.0, 3.0, 4.0])
    x2 = torch.tensor([2.0, 1.0, 2.0, 4.0])
    x3 = torch.tensor([3.0, 3.0, 1.0, 4.0])

    summarizer = Summarizer([Mean(), Var()])

    # (tensor to feed, expected running mean, expected running variance)
    steps = [
        (x1, x1, torch.zeros_like(x1)),
        (
            x2,
            torch.tensor([1.5, 1.5, 2.5, 4]),
            torch.tensor([0.25, 0.25, 0.25, 0]),
        ),
        (
            x3,
            torch.tensor([2, 2, 2, 4]),
            torch.tensor([2.0 / 3.0, 2.0 / 3.0, 2.0 / 3.0, 0]),
        ),
    ]

    for new_input, expected_mean, expected_variance in steps:
        summarizer.update(new_input)
        assertArraysAlmostEqual(summarizer.summary["mean"], expected_mean)
        assertArraysAlmostEqual(summarizer.summary["variance"], expected_variance)
def evaluate(
    self,
    inputs: Any,
    additional_forward_args: Any = None,
    perturbations_per_eval: int = 1,
    **kwargs,
) -> Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]]:
    r"""
    Evaluate model and attack performance on provided inputs

    Args:

        inputs (any): Input for which attack metrics
                are computed. It can be provided as a tensor, tuple of tensors,
                or any raw input type (e.g. PIL image or text string).
                This input is provided directly as input to preproc function
                as well as any attack applied before preprocessing. If no
                pre-processing function is provided, this input is provided
                directly to the main model and all attacks.

        additional_forward_args (any, optional): If the forward function
                requires additional arguments other than the preprocessing
                outputs (or inputs if preproc_fn is None), this argument
                can be provided. It must be either a single additional
                argument of a Tensor or arbitrary (non-tuple) type or a
                tuple containing multiple additional arguments including
                tensors or any arbitrary python types. These arguments are
                provided to forward_func in order following the arguments
                in inputs.
                For a tensor, the first dimension of the tensor must
                correspond to the number of examples. For all other types,
                the given argument is used for all forward evaluations.
                Default: None

        perturbations_per_eval (int, optional): Allows perturbations of
                multiple attacks to be grouped and evaluated in one call
                of forward_fn.
                Each forward pass will contain a maximum of
                perturbations_per_eval * #examples samples.
                For DataParallel models, each batch is split among the
                available devices, so evaluations on each available
                device contain at most
                (perturbations_per_eval * #examples) / num_devices
                samples.
                In order to apply this functionality, the output of
                preproc_fn (or inputs itself if no preproc_fn is provided)
                must be a tensor or tuple of tensors.
                Default: 1

        kwargs (any, optional): Additional keyword arguments provided to
                metric function as well as selected attacks based on
                chosen additional_args

    Returns:

        - **attack results** Dict: str -> Dict[str, Union[Tensor, Tuple[Tensor, ...]]]:
            Dictionary containing attack results for provided batch.
            Maps attack name to dictionary,
            containing best-case, worst-case and average-case results for attack.
            Dictionary contains keys "mean", "max" and "min" when num_attempts > 1
            and only "mean" for num_attempts = 1, which contains the (single) metric
            result for the attack attempt.
            An additional key of 'Original' is included with metric results
            without any perturbations.


    Examples::

    >>> def accuracy_metric(model_out: Tensor, targets: Tensor):
    >>>     return (torch.argmax(model_out, dim=1) == targets).float()

    >>> attack_metric = AttackComparator(model=resnet18,
                                         metric=accuracy_metric,
                                         preproc_fn=normalize)

    >>> random_rotation = transforms.RandomRotation()
    >>> jitter = transforms.ColorJitter()

    >>> attack_metric.add_attack(random_rotation, "Random Rotation",
    >>>                          num_attempts = 5)
    >>> attack_metric.add_attack(jitter, "Jitter", num_attempts = 1)
    >>> attack_metric.add_attack(FGSM(resnet18), "FGSM 0.1", num_attempts = 1,
    >>>                          apply_before_preproc=False,
    >>>                          attack_kwargs={"epsilon": 0.1},
    >>>                          additional_args=["targets"])

    >>> for images, labels in dataloader:
    >>>     batch_results = attack_metric.evaluate(inputs=images, targets=labels)
    """
    additional_forward_args = _format_additional_forward_args(
        additional_forward_args
    )
    # Full-size groups hold exactly `perturbations_per_eval` inputs, so the
    # forward args only need expanding when that size is greater than one.
    expanded_additional_args = (
        _expand_additional_forward_args(
            additional_forward_args, perturbations_per_eval
        )
        if perturbations_per_eval > 1
        else additional_forward_args
    )

    # Decide once whether preprocessing applies, and use the same test for the
    # original and the attacked inputs. Relying on truthiness for one of them
    # would skip preprocessing for a callable that evaluates as falsy.
    has_preproc = self.preproc_fn is not None
    preproc_input = self.preproc_fn(inputs) if has_preproc else inputs

    # Pending inputs (and the summarizer key each belongs to) that have not
    # been evaluated yet; the unperturbed input always goes first.
    input_list = [preproc_input]
    key_list = [ORIGINAL_KEY]

    batch_summarizers = {ORIGINAL_KEY: Summarizer([Mean()])}
    if ORIGINAL_KEY not in self.summary_results:
        self.summary_results[ORIGINAL_KEY] = Summarizer(
            [stat() for stat in self.aggregate_stats]
        )

    def _check_and_evaluate(input_list, key_list):
        # Evaluate and reset the pending lists once a full group has been
        # collected; otherwise hand the lists back untouched.
        if len(input_list) == perturbations_per_eval:
            self._evaluate_batch(
                input_list,
                expanded_additional_args,
                key_list,
                batch_summarizers,
                kwargs,
            )
            return [], []
        return input_list, key_list

    input_list, key_list = _check_and_evaluate(input_list, key_list)

    for attack_key in self.attacks:
        attack = self.attacks[attack_key]
        # A single attempt only needs its mean; several attempts are
        # summarized with the configured batch statistics.
        if attack.num_attempts > 1:
            stats = [stat() for stat in self.batch_stats]
        else:
            stats = [Mean()]
        batch_summarizers[attack.name] = Summarizer(stats)

        # Forward only the keyword arguments this attack asked for; a missing
        # one is reported but does not abort the evaluation.
        additional_attack_args = {}
        for key in attack.additional_args:
            if key not in kwargs:
                warnings.warn(
                    f"Additional sample arg {key} not provided for {attack_key}"
                )
            else:
                additional_attack_args[key] = kwargs[key]

        for _ in range(attack.num_attempts):
            if attack.apply_before_preproc:
                # Attack the raw input, then preprocess the attacked result.
                attacked_inp = attack.attack_fn(
                    inputs, **additional_attack_args, **attack.attack_kwargs
                )
                preproc_attacked_inp = (
                    self.preproc_fn(attacked_inp) if has_preproc else attacked_inp
                )
            else:
                # Attack the already preprocessed input directly.
                preproc_attacked_inp = attack.attack_fn(
                    preproc_input, **additional_attack_args, **attack.attack_kwargs
                )

            input_list.append(preproc_attacked_inp)
            key_list.append(attack.name)

            input_list, key_list = _check_and_evaluate(input_list, key_list)

    # Evaluate whatever is left over as a final, smaller group, with the
    # forward args expanded to match its actual size.
    if input_list:
        final_add_args = _expand_additional_forward_args(
            additional_forward_args, len(input_list)
        )
        self._evaluate_batch(
            input_list, final_add_args, key_list, batch_summarizers, kwargs
        )

    return self._parse_and_update_results(batch_summarizers)