def gradient_func(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
) -> Tuple[Tensor, ...]:
    if self.device_ids is None:
        scattered_inputs = (inputs,)
    else:
        # scatter method does not have a precise enough return type in its
        # stub, so suppress the type warning.
        scattered_inputs = scatter(  # type: ignore
            inputs, target_gpus=self.device_ids
        )

    scattered_inputs_dict = {
        scattered_input[0].device: scattered_input
        for scattered_input in scattered_inputs
    }

    with torch.autograd.set_grad_enabled(True):

        def layer_forward_hook(module, hook_inputs, hook_outputs=None):
            device = _extract_device(module, hook_inputs, hook_outputs)
            is_layer_tuple = (
                isinstance(hook_outputs, tuple)
                if hook_outputs is not None
                else isinstance(hook_inputs, tuple)
            )
            if is_layer_tuple:
                return scattered_inputs_dict[device]
            return scattered_inputs_dict[device][0]

        hook = None
        try:
            if attribute_to_layer_input:
                hook = self.layer.register_forward_pre_hook(layer_forward_hook)
            else:
                hook = self.layer.register_forward_hook(layer_forward_hook)
            output = _run_forward(
                self.forward_func, tuple(), target_ind, additional_forward_args
            )
        finally:
            if hook is not None:
                hook.remove()

        assert output[0].numel() == 1, (
            "Target not provided when necessary, cannot"
            " take gradient with respect to multiple outputs."
        )
        # torch.unbind(forward_out) is a list of scalar tensor tuples and
        # contains batch_size * #steps elements
        grads = torch.autograd.grad(torch.unbind(output), inputs)
    return grads
def gradient_func(
    forward_fn: Callable,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    target_ind: TargetType = None,
    additional_forward_args: Any = None,
) -> Tuple[Tensor, ...]:
    if self.device_ids is None or len(self.device_ids) == 0:
        scattered_inputs = (inputs,)
    else:
        # scatter method does not have a precise enough return type in its
        # stub, so suppress the type warning.
        scattered_inputs = scatter(  # type: ignore
            inputs, target_gpus=self.device_ids
        )

    scattered_inputs_dict = {
        scattered_input[0].device: scattered_input
        for scattered_input in scattered_inputs
    }

    with torch.autograd.set_grad_enabled(True):

        def layer_forward_hook(
            module, hook_inputs, hook_outputs=None, layer_idx=0
        ):
            device = _extract_device(module, hook_inputs, hook_outputs)
            is_layer_tuple = (
                isinstance(hook_outputs, tuple)
                # hook_outputs is None if attribute_to_layer_input == True
                if hook_outputs is not None
                else isinstance(hook_inputs, tuple)
            )

            if is_layer_tuple:
                return scattered_inputs_dict[device][
                    num_outputs_cumsum[layer_idx] : num_outputs_cumsum[
                        layer_idx + 1
                    ]
                ]

            return scattered_inputs_dict[device][num_outputs_cumsum[layer_idx]]

        hooks = []
        try:
            layers = self.layer
            if not isinstance(layers, list):
                layers = [self.layer]

            for layer_idx, layer in enumerate(layers):
                hook = None
                # TODO:
                # Allow multiple attribute_to_layer_input flags for
                # each layer, i.e. attribute_to_layer_input[layer_idx]
                if attribute_to_layer_input:
                    hook = layer.register_forward_pre_hook(
                        functools.partial(
                            layer_forward_hook, layer_idx=layer_idx
                        )
                    )
                else:
                    hook = layer.register_forward_hook(
                        functools.partial(
                            layer_forward_hook, layer_idx=layer_idx
                        )
                    )
                hooks.append(hook)

            output = _run_forward(
                self.forward_func, tuple(), target_ind, additional_forward_args
            )
        finally:
            for hook in hooks:
                if hook is not None:
                    hook.remove()

        assert output[0].numel() == 1, (
            "Target not provided when necessary, cannot"
            " take gradient with respect to multiple outputs."
        )
        # torch.unbind(forward_out) is a list of scalar tensor tuples and
        # contains batch_size * #steps elements
        grads = torch.autograd.grad(torch.unbind(output), inputs)
    return grads
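If this second `gradient_func` is the one used internally by Captum's `LayerIntegratedGradients` (an assumption based on the helper names such as `_run_forward`, `_extract_device`, and `attribute_to_layer_input`), the per-layer hooks and the `num_outputs_cumsum` indexing are what allow a single attribution call to cover a list of layers rather than one layer. A minimal usage sketch, assuming that public API; the model and target index here are placeholders chosen only for illustration:

import torch
from captum.attr import LayerIntegratedGradients
from torchvision.models import resnet18

# Toy setup: any nn.Module with named submodules works as the "layers".
model = resnet18().eval()
inputs = torch.randn(4, 3, 224, 224)

# Passing a list of layers exercises the multi-layer path above: one forward
# hook is registered per layer, and attribute() returns one result per layer.
lig = LayerIntegratedGradients(model, [model.layer2, model.layer4])
attributions = lig.attribute(inputs, target=0)

for layer_attr in attributions:
    print(type(layer_attr), layer_attr[0].shape if isinstance(layer_attr, tuple) else layer_attr.shape)

The design point the modified code captures is that all layer activations for a forward pass are concatenated into one flat tuple, so `num_outputs_cumsum[layer_idx]:num_outputs_cumsum[layer_idx + 1]` is the slice of that tuple belonging to layer `layer_idx`, and `functools.partial(..., layer_idx=layer_idx)` binds the right slice to each layer's hook.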