def test_efficiency_property(self):
    """Efficiency axiom: Shapley values must sum to the grand-coalition score."""
    features = torch.arange(3)  # the "players": [0, 1, 2]
    uniform = Uniform(0, 1)
    # Pre-draw one random score per subset of players; the empty coalition
    # is pinned to 0 so the characteristic function is normalised.
    subset_scores = {
        frozenset(subset): [uniform.sample().item()]
        for subset in powerset(features.numpy())
    }
    subset_scores[frozenset({})] = [0.0]

    def characteristic_fn(batch_seq_features):
        # Look up the pre-drawn score for every subset in the batch.
        batch_results = [
            subset_scores[frozenset(list(row))]
            for row in batch_seq_features.numpy()
        ]
        return torch.tensor(batch_results, dtype=torch.float32)

    attributor = CharacteristicFunctionExampleShapleyAttributor(
        features,
        characteristic_fn=characteristic_fn,
        iterations=1,
        subset_sampler=ExhaustiveSubsetSampler(),
        n_classes=1,
    )
    shapley_values, scores = attributor.run()
    assert_array_equal(
        shapley_values.sum(dim=0).numpy(),
        subset_scores[frozenset(list(features.numpy()))],
    )
def make_shapley_attributor(self):
    """Build an ``OnlineShapleyAttributor`` over this fixture's models.

    Uses exhaustive subset sampling so the attributions are exact.
    """
    sampler = ExhaustiveSubsetSampler()
    return OnlineShapleyAttributor(
        self.mlps,
        self.priors,
        n_classes=self.n_classes,
        subset_sampler=sampler,
    )
def test_against_results_are_the_same_as_the_multiscale_model(self):
    """Recursive and direct multiscale models must agree for every input length."""
    shared_sampler = ExhaustiveSubsetSampler()
    direct = MultiscaleModel(
        self.mlps, softmax=False, sampler=shared_sampler, save_intermediate=True
    )
    recursive = RecursiveMultiscaleModel(
        self.mlps,
        sampler=shared_sampler,
        save_intermediate=True,
        count_n_evaluations=False,
    )
    # Cover inputs both shorter and longer than the number of scales.
    for n_video_frames in range(1, len(self.mlps) + 8):
        example = torch.randn(n_video_frames, self.input_dim)
        with torch.no_grad():
            expected = direct(example.clone()).numpy()
        with torch.no_grad():
            actual = recursive(example.clone()).numpy()
        np.testing.assert_allclose(
            actual,
            expected,
            err_msg=f"Failure comparing scores for a {n_video_frames} frame input",
            rtol=1e-4,
        )
def __init__(
    self,
    single_scale_models: Union[List[nn.Module], nn.ModuleList],
    softmax: bool = False,
    save_intermediate: bool = False,
    sampler: Optional[SubsetSampler] = None,
):
    """
    Args:
        single_scale_models: One model per scale; the model at index ``i``
            scores subsequences of length ``i + 1``.
        softmax: If ``True``, apply a softmax to each scale's output.
        save_intermediate: If ``True``, per-scale results are recorded in
            ``self.intermediates`` during the forward pass.
        sampler: Strategy for drawing fixed-length subsequences; defaults to
            exhaustive enumeration.
    """
    super().__init__()
    # Wrap in a ModuleList so the sub-models are registered as children.
    self.single_scale_models = nn.ModuleList(list(single_scale_models))
    self.sampler = ExhaustiveSubsetSampler() if sampler is None else sampler
    self.softmax = softmax
    self.save_intermediate = save_intermediate
    # Populated by forward() when save_intermediate is set.
    self.intermediates = None
def test_scores_against_multiscale_model(self):
    """The attributor's mean scores must reproduce the multiscale model output."""
    reference_model = MultiscaleModel(self.mlps, sampler=ExhaustiveSubsetSampler())
    reference_model.eval()
    attributor = self.make_shapley_attributor()
    example = torch.randn(5, self.input_dim)  # 5 video frames
    with torch.no_grad():
        expected_scores = reference_model(example)
    shapley_values, attributor_scores = attributor.explain(example)
    np.testing.assert_allclose(
        attributor_scores.mean(axis=0).numpy(),
        expected_scores.numpy(),
        rtol=1e-5,
    )
def __init__(
    self,
    characteristic_fn: Callable[[torch.Tensor], torch.Tensor],
    n_classes: int,
    subset_sampler: Optional[SubsetSampler] = None,
    device: torch.device = None,
):
    """
    Args:
        characteristic_fn: Maps a batch of feature subsets to score tensors.
        n_classes: Number of output classes scored by ``characteristic_fn``.
        subset_sampler: Sampler over feature subsets; defaults to exhaustive
            enumeration on ``device``.
        device: Device handed to the default sampler and stored for later use.
    """
    self.characteristic_fn = characteristic_fn
    self.n_classes = n_classes
    # Exact (exhaustive) sampling is the default when none is supplied.
    self.subset_sampler = (
        ExhaustiveSubsetSampler(device=device)
        if subset_sampler is None
        else subset_sampler
    )
    self.device = device
    # Set by later calls; None until an attribution has been run.
    self.last_attributor = None
def test_exhaustive_sampling_against_naive_implementation(self):
    """Online attribution with exhaustive sampling must match the naive version."""
    online = OnlineShapleyAttributor(
        self.mlps,
        self.priors,
        self.n_classes,
        subset_sampler=ExhaustiveSubsetSampler(),
    )
    # More frames than models, so scales beyond the model list are exercised.
    example = torch.randn(len(self.mlps) + 2, self.input_dim)
    online_values = online.explain(example)[0].numpy()
    naive = NaiveShapleyAttributor(self.mlps, self.priors, self.n_classes)
    naive_values = naive.explain(example)[0].numpy()
    np.testing.assert_allclose(
        online_values,
        naive_values,
        atol=1e-7,
        rtol=1e-7,
        verbose=True,
    )
def __init__(
    self,
    video: torch.Tensor,
    characteristic_fn: Callable[[torch.Tensor], torch.Tensor],
    iterations: int,
    n_classes: int,
    subset_sampler: Optional[SubsetSampler] = None,
    device: Optional[torch.device] = None,
    characteristic_fn_args: Optional[Tuple[Any, ...]] = None,
    characteristic_fn_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Args:
        video: Sequence of features, one element per frame; its length sets
            the number of elements to attribute.
        characteristic_fn: Maps a batch of feature subsets to score tensors.
        iterations: Number of attribution iterations to accumulate.
        n_classes: Number of classes scored by ``characteristic_fn``.
        subset_sampler: Sampler over subsets; defaults to exhaustive
            enumeration on ``device``.
        device: Device on which the score accumulators are allocated.
        characteristic_fn_args: Extra positional args for
            ``characteristic_fn``; defaults to no extra args.
        characteristic_fn_kwargs: Extra keyword args for
            ``characteristic_fn``; defaults to no extra kwargs.
    """
    # Fix: default to an empty tuple, matching the declared
    # Optional[Tuple[Any, ...]] type (the previous default was a list).
    if characteristic_fn_args is None:
        characteristic_fn_args = ()
    if characteristic_fn_kwargs is None:
        characteristic_fn_kwargs = {}
    self.characteristic_fn = characteristic_fn
    self.characteristic_fn_args = characteristic_fn_args
    self.characteristic_fn_kwargs = characteristic_fn_kwargs
    self.sequence_features = video
    self.n_classes = n_classes
    self.n_elements = len(video)
    self.n_iterations = iterations
    if subset_sampler is None:
        subset_sampler = ExhaustiveSubsetSampler(device=device)
    self.subset_sampler = subset_sampler
    self.device = device
    # One scale per subsequence length plus one extra scale.
    self.n_scales = self.n_elements + 1
    # Score accumulators over
    # (iteration, with/without-element, scale, element, class) — the second
    # axis of size 2 presumably distinguishes subsets containing the element
    # from those without it; confirm against the run() implementation.
    self.summed_scores = torch.zeros(
        (self.n_iterations, 2, self.n_scales, self.n_elements, self.n_classes),
        device=device,
        dtype=torch.float32,
    )
    # Counts of how many scores were folded into each accumulator cell.
    self.n_summed_scores = torch.zeros(
        (self.n_iterations, 2, self.n_scales, self.n_elements),
        device=device,
        dtype=torch.long,
    )
def get_subset_sampler(args, device):
    """Choose a subset sampler from parsed CLI options.

    Returns a ``ConstructiveRandomSampler`` when ``args.approximate`` is set,
    otherwise an exact ``ExhaustiveSubsetSampler``.
    """
    if not args.approximate:
        return ExhaustiveSubsetSampler(device=device)
    return ConstructiveRandomSampler(
        max_samples=args.approximate_max_samples_per_scale,
        device=device,
    )
def make_sampler(self, data):
    """Return an exhaustive subset sampler; ``data`` is accepted but unused."""
    sampler = ExhaustiveSubsetSampler()
    return sampler
class RecursiveMultiscaleModel(nn.Module):
    r"""
    Implements the multiscale model defined as:

    .. math::

        f(X) = \mathbb{E}_s \left[
            \mathbb{E}_{\substack{X' \subseteq X \\ |X'| = s}} [f_s(X')]
        \right]

    But rather than computing it in this fashion, it computes it in a
    bottom-up fashion so that :math:`f(X')` for subsets :math:`X'` is computed
    and used to produce the output for the next scale up. This way we get the
    outputs :math:`f(X')` as a side effect of computing :math:`f(X)` for free.
    This is useful for Shapley Value analysis where we need these intermediate
    values.

    This is accomplished by reformulating the above into a recurrence:

    .. math::

        f(X) = \begin{cases}
            \mathbb{E}_{\substack{X' \subseteq X \\ |X'| = |X| - 1}}[f(X')]
                & |X| \geq n_{\max{}} \\
            |X|^{-1} \left( f_{|X|}(X)
                + (|X| - 1)
                  \mathbb{E}_{\substack{X' \subset X \\ |X'| = |X| - 1}}[f(X')]
              \right)
                & \text{otherwise}
        \end{cases}
    """

    def __init__(
        self,
        single_scale_models: Union[List[nn.Module], nn.ModuleList],
        save_intermediate: bool = False,
        sampler: Optional[SubsetSampler] = None,
        count_n_evaluations: bool = False,
        softmax: bool = False,
    ):
        """
        Args:
            single_scale_models: One model per scale; the model at index ``i``
                scores subsequences of length ``i + 1``.
            save_intermediate: If ``True``, per-scale results are recorded in
                ``self.intermediates`` during the forward pass.
            sampler: Strategy for drawing fixed-length subsequences; defaults
                to exhaustive enumeration.
            count_n_evaluations: If ``True``, ensemble scales weighted by the
                tracked number of evaluations rather than by scale index.
            softmax: If ``True``, apply a softmax to each scale's raw output.
        """
        super().__init__()
        # ModuleList registers the sub-models as children of this module.
        self.single_scale_models = nn.ModuleList([model for model in single_scale_models])
        self.softmax = softmax
        if sampler is None:
            self.sampler = ExhaustiveSubsetSampler()
        else:
            self.sampler = sampler
        self.save_intermediate = save_intermediate
        # Populated by forward() when save_intermediate is set.
        self.intermediates = None
        self.count_n_evaluations = count_n_evaluations

    def forward(self, sequence: torch.Tensor) -> torch.Tensor:
        """
        Args:
            sequence: Example of shape :math:`(T, C)` where :math:`T` is the
                number of elements in the sequence and :math:`C` is the channel
                size.

        Returns:
            Class scores of shape `(C',)` where :math:`C'` is the number of
            classes.
        """
        if self.save_intermediate:
            self.intermediates = defaultdict(lambda: dict())
        sequence_len = sequence.shape[0]
        # Seed the recurrence with an empty score tensor; the class count is
        # read off the last layer of the first single-scale model.
        previous_scores = torch.zeros(
            (0, self.single_scale_models[0].model[-1].out_features),
            dtype=torch.float,
            device=sequence.device,
        )
        previous_subsequence_idx = torch.tensor(
            [[]], dtype=torch.long, device=sequence.device
        )
        previous_n_evaluations = torch.zeros(
            (1,), dtype=torch.float32, device=sequence.device
        )
        # Some samplers are stateless and expose no reset(); ignore those.
        try:
            self.sampler.reset()
        except AttributeError:
            pass
        # Walk the scales bottom-up: subsequences of length 1 .. sequence_len.
        for scale_idx in range(sequence_len):
            subsequence_len = scale_idx + 1
            current_subsequence_idxs = self.sampler.sample(
                sequence_len, subsequence_len
            )
            subsequences = sequence[current_subsequence_idxs]
            # For each current subsequence, which previous subsequences are
            # subsets of it (used to average the previous scale's scores).
            subset_relations = compute_subset_relations(
                current_subsequence_idxs, previous_subsequence_idx
            )
            current_n_evaluations = (
                masked_mean(subset_relations, previous_n_evaluations)
            ) + 1
            if scale_idx < len(self.single_scale_models):
                single_scale_model = self.single_scale_models[scale_idx]
                current_scores = single_scale_model(subsequences)
                if self.softmax:
                    current_scores = F.softmax(current_scores, dim=-1)
            else:
                # Beyond the largest model: first branch of the recurrence —
                # only the average of one-element-smaller subsets.
                current_scores = masked_mean(subset_relations, previous_scores)
            if self.save_intermediate:
                # NOTE(review): .cpu().numpy() can share memory with
                # current_scores when it already lives on the CPU, and
                # current_scores is mutated in place (add_/div_) just below —
                # the stored "scores" may then silently become the ensembled
                # values. Confirm whether a .copy() is intended here.
                self.intermediates[scale_idx]["scores"] = current_scores.cpu().numpy()
            if sequence_len >= 2 and 0 < scale_idx < len(self.single_scale_models):
                if self.count_n_evaluations:
                    # Weight the previous scale by its evaluation counts, then
                    # renormalise by the updated counts.
                    current_scores.add_(
                        other=masked_mean(
                            subset_relations,
                            previous_scores * previous_n_evaluations[:, None],
                        )
                    ).div_(current_n_evaluations[:, None])
                else:
                    # Second branch of the recurrence:
                    # (f_{|X|}(X) + (|X|-1) * E[f(X')]) / |X|,
                    # with |X| = scale_idx + 1 (alpha scales `other` in add_).
                    current_scores.add_(
                        alpha=scale_idx,
                        other=masked_mean(subset_relations, previous_scores),
                    ).div_(scale_idx + 1)
            if self.save_intermediate:
                self.intermediates[scale_idx][
                    "ensembled_scores"
                ] = current_scores.cpu().numpy()
                self.intermediates[scale_idx][
                    "subsequence_idxs"
                ] = current_subsequence_idxs.cpu().numpy()
                self.intermediates[scale_idx][
                    "current_n_evaluations"
                ] = current_n_evaluations.cpu().numpy()
            # Carry this scale's results into the next iteration.
            previous_scores = current_scores
            previous_subsequence_idx = current_subsequence_idxs
            previous_n_evaluations = current_n_evaluations
        # Average over the subsequences sampled at the final scale.
        return current_scores.mean(dim=0)

    def to(self, *args, **kwargs):
        """Move the module and keep the sampler's device in sync.

        NOTE(review): torch._C._nn._parse_to is a private API; its signature
        may change across torch versions.
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )
        self.sampler.device = device
        return super().to(*args, **kwargs)
class MultiscaleModel(nn.Module):
    r"""
    Implements the multiscale model defined as:

    .. math::

        f(X) = \mathbb{E}_s \left[
            \mathbb{E}_{\substack{X' \subseteq X \\ |X'| = s}} [f_s(X')]
        \right]
    """

    def __init__(
        self,
        single_scale_models: Union[List[nn.Module], nn.ModuleList],
        softmax: bool = False,
        save_intermediate: bool = False,
        sampler: Optional[SubsetSampler] = None,
    ):
        """
        Args:
            single_scale_models: One model per scale; the model at index ``i``
                scores subsequences of length ``i + 1``.
            softmax: If ``True``, apply a softmax to each scale's raw output.
            save_intermediate: If ``True``, per-scale results are recorded in
                ``self.intermediates`` during the forward pass.
            sampler: Strategy for drawing fixed-length subsequences; defaults
                to exhaustive enumeration.
        """
        super().__init__()
        # ModuleList registers the sub-models as children of this module.
        self.single_scale_models = nn.ModuleList([model for model in single_scale_models])
        if sampler is None:
            self.sampler = ExhaustiveSubsetSampler()
        else:
            self.sampler = sampler
        self.softmax = softmax
        self.save_intermediate = save_intermediate
        # Populated by forward() when save_intermediate is set.
        self.intermediates = None

    def forward(self, sequence: torch.Tensor) -> torch.Tensor:
        """
        Args:
            sequence: Example of shape :math:`(T, C)` where :math:`T` is the
                number of elements in the sequence and :math:`C` is the channel
                size.

        Returns:
            Class scores of shape `(C',)` where :math:`C'` is the number of
            classes.
        """
        # Some samplers are stateless and expose no reset(); ignore those.
        try:
            self.sampler.reset()
        except AttributeError:
            pass
        if self.save_intermediate:
            self.intermediates = defaultdict(lambda: dict())
        sequence_len = sequence.shape[0]
        scores = None
        # Cannot use more scales than there are models or sequence elements.
        n_scales = min(sequence_len, len(self.single_scale_models))
        for scale_idx in range(n_scales):
            subsequence_len = scale_idx + 1
            current_subsequence_idxs = self.sampler.sample(
                sequence_len, subsequence_len
            )
            subsequences = sequence[current_subsequence_idxs]
            # Always true given n_scales above; kept as a guard.
            if scale_idx < len(self.single_scale_models):
                single_scale_model = self.single_scale_models[scale_idx]
                current_scores = single_scale_model(subsequences)
                if self.softmax:
                    current_scores = F.softmax(current_scores, dim=-1)
            # Accumulate the per-scale mean score in place.
            if scores is None:
                scores = current_scores.mean(dim=0)
            else:
                scores.add_(current_scores.mean(dim=0))
            if self.save_intermediate:
                # NOTE(review): scores is mutated in place by add_/div_ in
                # later iterations, and .cpu().numpy() can share memory when
                # the tensor is already on the CPU — the "ensembled_scores"
                # saved here may end up reflecting the final values. Confirm
                # whether a .copy() is intended.
                self.intermediates[scale_idx]["ensembled_scores"] = scores.cpu().numpy()
                self.intermediates[scale_idx]["scores"] = current_scores.cpu().numpy()
                self.intermediates[scale_idx][
                    "subsequence_idxs"
                ] = current_subsequence_idxs
        # Turn the running sum over scales into the expectation over s.
        scores.div_(n_scales)
        return scores

    def to(self, *args, **kwargs):
        """Move the module and keep the sampler's device in sync.

        NOTE(review): torch._C._nn._parse_to is a private API; its signature
        may change across torch versions.
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )
        self.sampler.device = device
        return super().to(*args, **kwargs)
def instantiate(self):
    """Create and return a fresh ``ExhaustiveSubsetSampler``."""
    sampler = ExhaustiveSubsetSampler()
    return sampler