def _expand_target(
    target: TargetType,
    n_steps: int,
    expansion_type: ExpansionTypes = ExpansionTypes.repeat,
) -> TargetType:
    """
    Replicate ``target`` ``n_steps`` times.

    Only lists and tensors holding more than one element are expanded:

    * ``ExpansionTypes.repeat`` repeats the whole sequence
      (``[a, b]`` -> ``[a, b, a, b]``; tensors are concatenated along dim 0).
    * ``ExpansionTypes.repeat_interleave`` repeats each element in place
      (``[a, b]`` -> ``[a, a, b, b]``; tensors via ``repeat_interleave`` on
      dim 0).

    Any other target (e.g. ``None``, an int, a tuple, or a single-element
    tensor) is returned unchanged, whatever ``expansion_type`` is.

    Raises:
        NotImplementedError: if ``target`` is a list or a multi-element tensor
            and ``expansion_type`` is neither of the two supported values.
    """
    unsupported_msg = (
        "Currently only `repeat` and `repeat_interleave`"
        " expansion_types are supported"
    )

    if isinstance(target, list):
        if expansion_type == ExpansionTypes.repeat:
            return target * n_steps
        if expansion_type == ExpansionTypes.repeat_interleave:
            # Emit every element n_steps times before moving to the next one.
            interleaved = [item for item in target for _ in range(n_steps)]
            return cast(Union[List[Tuple[int, ...]], List[int]], interleaved)
        raise NotImplementedError(unsupported_msg)

    if isinstance(target, torch.Tensor) and torch.numel(target) > 1:
        if expansion_type == ExpansionTypes.repeat:
            return torch.cat([target] * n_steps, dim=0)
        if expansion_type == ExpansionTypes.repeat_interleave:
            return target.repeat_interleave(n_steps, dim=0)
        raise NotImplementedError(unsupported_msg)

    # Nothing per-example to expand: the same target applies everywhere.
    return target
def _select_targets(output: Tensor, target: TargetType) -> Tensor:
    """
    Select from ``output`` the entries identified by ``target``.

    Args:
        output: Tensor whose first dimension indexes examples.
        target: One of
            * ``None`` -- ``output`` is returned unchanged.
            * ``int`` or ``tuple`` -- delegated to ``_verify_select_column``
              (the same selection for every example).
            * a single-element ``Tensor`` holding an int -- treated like that
              int.
            * a 1D ``Tensor`` with one entry per example -- per-example column
              index; ``output`` must be 2D. The result has shape
              ``(num_examples, 1)``.
            * a ``list`` with one entry per example, either ints (``output``
              must be 2D; result shape ``(num_examples, 1)``) or tuples, which
              index the remaining dimensions of each example's row. The
              per-example selections are stacked.

    Returns:
        The selected values as a Tensor.

    Raises:
        AssertionError: if the target's type, length or shape is not one of the
            supported forms above. Explicit ``raise`` is used instead of
            ``assert`` so the checks still run under ``python -O``.
    """
    if target is None:
        return output

    num_examples = output.shape[0]
    dims = len(output.shape)
    device = output.device

    if isinstance(target, (int, tuple)):
        return _verify_select_column(output, target)

    if isinstance(target, torch.Tensor):
        if torch.numel(target) == 1 and isinstance(target.item(), int):
            return _verify_select_column(output, cast(int, target.item()))
        if len(target.shape) == 1 and torch.numel(target) == num_examples:
            if dims != 2:
                raise AssertionError("Output must be 2D to select tensor of targets.")
            # torch.gather needs the index on the same device as the source.
            # The list branch below already builds its index on `device`;
            # `.to` is a no-op if target is already there.
            index = target.to(device).reshape(len(output), 1)
            return torch.gather(output, 1, index)
        raise AssertionError(
            "Tensor target dimension %r is not valid. %r"
            % (target.shape, output.shape)
        )

    if isinstance(target, list):
        if len(target) != num_examples:
            raise AssertionError("Target list length does not match output!")
        if isinstance(target[0], int):
            if dims != 2:
                raise AssertionError("Output must be 2D to select tensor of targets.")
            index = torch.tensor(target, device=device).reshape(len(output), 1)
            return torch.gather(output, 1, index)
        if isinstance(target[0], tuple):
            # Prepend the example index to each tuple so it addresses one
            # element within that example's slice of `output`.
            return torch.stack(
                [
                    output[(i,) + cast(Tuple, targ_elem)]
                    for i, targ_elem in enumerate(target)
                ]
            )
        raise AssertionError("Target element type in list is not valid.")

    # Wrap in a 1-tuple so %-formatting never unpacks the value itself.
    raise AssertionError("Target type %r is not valid." % (target,))
def _batched_generator(
    inputs: TensorOrTupleOfTensorsGeneric,
    additional_forward_args: Any = None,
    target_ind: TargetType = None,
    internal_batch_size: Union[None, int] = None,
) -> Iterator[Tuple[Tuple[Tensor, ...], Any, TargetType]]:
    """
    Returns a generator which yields corresponding chunks of size
    ``internal_batch_size`` for inputs, additional_forward_args and targets.

    Each item is ``(inputs_chunk, additional_forward_args_chunk, target_chunk)``.
    Chunks are taken as half-open ranges ``[start, start + internal_batch_size)``
    over the first dimension of ``inputs[0]``. The slicing itself is delegated
    to ``_tuple_splice_range``.

    ``target_ind`` is sliced the same way only when it is a list or a tensor
    with more than one element. Any other target (``None``, an int, a tuple,
    or a single-element tensor) is passed unchanged with every chunk.

    If ``internal_batch_size`` is None, a single item with the formatted
    original inputs, additional args and target is yielded.

    Because this is a generator, the validation and the warning below only run
    when iteration starts.

    Raises:
        AssertionError: if ``internal_batch_size`` is neither None nor a
            positive int. Explicit ``raise`` keeps the check active under
            ``python -O``.
    """
    if internal_batch_size is not None and not (
        isinstance(internal_batch_size, int) and internal_batch_size > 0
    ):
        raise AssertionError("Batch size must be greater than 0.")

    inputs = _format_tensor_into_tuples(inputs)
    additional_forward_args = _format_additional_forward_args(additional_forward_args)
    num_examples = inputs[0].shape[0]

    # TODO Reconsider this check if _batched_generator is used for non gradient-based
    # attribution algorithms
    # Same outcome as testing `(inputs[0] * 1).requires_grad` (the result
    # tracks grad only when grad mode is on and the input requires grad), but
    # without allocating a full copy of the input.
    if not (torch.is_grad_enabled() and inputs[0].requires_grad):
        warnings.warn(
            "It looks like that the attribution for a gradient-based method is"
            " computed in a `torch.no_grad` block or perhaps the inputs have no"
            " requires_grad."
        )

    if internal_batch_size is None:
        yield inputs, additional_forward_args, target_ind
        return

    # Loop-invariant: only per-example targets get sliced.
    slice_target = isinstance(target_ind, list) or (
        isinstance(target_ind, torch.Tensor) and target_ind.numel() > 1
    )

    for start in range(0, num_examples, internal_batch_size):
        end = start + internal_batch_size
        # Keep the grad-enabled context local to the slicing. Grad mode is
        # global state: holding this context across the `yield` would leave
        # grad enabled in the consumer while the generator is suspended.
        with torch.autograd.set_grad_enabled(True):
            inputs_splice = _tuple_splice_range(inputs, start, end)
        args_splice = _tuple_splice_range(additional_forward_args, start, end)
        target_splice = target_ind[start:end] if slice_target else target_ind
        yield inputs_splice, args_splice, target_splice