Code example #1
        # Note: forward_fn is unused here; the body closes over
        # self.forward_func instead.
        def gradient_func(
            forward_fn: Callable,
            inputs: Union[Tensor, Tuple[Tensor, ...]],
            target_ind: TargetType = None,
            additional_forward_args: Any = None,
        ) -> Tuple[Tensor, ...]:
            if self.device_ids is None:
                scattered_inputs = (inputs,)
            else:
                # scatter method does not have a precise enough return type in its
                # stub, so suppress the type warning.
                scattered_inputs = scatter(  # type: ignore
                    inputs, target_gpus=self.device_ids)

            # Map each scattered chunk to its device so the forward hook below
            # can look up the chunk belonging to the device it runs on.
            scattered_inputs_dict = {
                scattered_input[0].device: scattered_input
                for scattered_input in scattered_inputs
            }

            with torch.autograd.set_grad_enabled(True):

                def layer_forward_hook(module, hook_inputs, hook_outputs=None):
                    # Replace the layer's inputs (pre-hook) or outputs (forward
                    # hook) with the scattered chunk for this device, preserving
                    # the original tuple/non-tuple shape.
                    device = _extract_device(module, hook_inputs, hook_outputs)
                    is_layer_tuple = (
                        isinstance(hook_outputs, tuple)
                        if hook_outputs is not None
                        else isinstance(hook_inputs, tuple)
                    )
                    if is_layer_tuple:
                        return scattered_inputs_dict[device]
                    return scattered_inputs_dict[device][0]

                hook = None
                try:
                    if attribute_to_layer_input:
                        hook = self.layer.register_forward_pre_hook(
                            layer_forward_hook)
                    else:
                        hook = self.layer.register_forward_hook(
                            layer_forward_hook)

                    # Forward with empty positional inputs; the hook above
                    # injects the scattered tensors at the target layer.
                    output = _run_forward(self.forward_func, tuple(),
                                          target_ind, additional_forward_args)
                finally:
                    if hook is not None:
                        hook.remove()

                assert output[0].numel() == 1, (
                    "Target not provided when necessary, cannot"
                    " take gradient with respect to multiple outputs.")
                # torch.unbind(output) yields a tuple of scalar tensors with
                # batch_size * #steps elements.
                grads = torch.autograd.grad(torch.unbind(output), inputs)
            return grads
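
The input substitution above relies on a standard PyTorch hook behavior: a forward pre-hook (or forward hook) that returns a non-None value replaces the module's inputs (or outputs). A minimal, self-contained sketch of that mechanism; the module and tensors here are illustrative only, not part of the snippet:

    import torch
    import torch.nn as nn

    lin = nn.Linear(4, 2)
    replacement = torch.ones(1, 4)

    def pre_hook(module, hook_inputs):
        # Returning a tuple from a forward pre-hook replaces the module's
        # positional inputs, just as layer_forward_hook does above.
        return (replacement,)

    handle = lin.register_forward_pre_hook(pre_hook)
    out = lin(torch.zeros(1, 4))  # actually computed on `replacement`
    handle.remove()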
Code example #2
        # Note: forward_fn is unused here; the body closes over
        # self.forward_func instead.
        def gradient_func(
            forward_fn: Callable,
            inputs: Union[Tensor, Tuple[Tensor, ...]],
            target_ind: TargetType = None,
            additional_forward_args: Any = None,
        ) -> Tuple[Tensor, ...]:
            # An empty device_ids list is treated the same as None
            # (single-device execution).
            if self.device_ids is None or len(self.device_ids) == 0:
                scattered_inputs = (inputs,)
            else:
                # scatter method does not have a precise enough return type in its
                # stub, so suppress the type warning.
                scattered_inputs = scatter(  # type: ignore
                    inputs, target_gpus=self.device_ids)

            # Map each scattered chunk to its device so the forward hook below
            # can look up the chunk belonging to the device it runs on.
            scattered_inputs_dict = {
                scattered_input[0].device: scattered_input
                for scattered_input in scattered_inputs
            }

            with torch.autograd.set_grad_enabled(True):

                def layer_forward_hook(module,
                                       hook_inputs,
                                       hook_outputs=None,
                                       layer_idx=0):
                    device = _extract_device(module, hook_inputs, hook_outputs)
                    # hook_outputs is None when attribute_to_layer_input is
                    # True, i.e. when this runs as a forward pre-hook.
                    is_layer_tuple = (
                        isinstance(hook_outputs, tuple)
                        if hook_outputs is not None
                        else isinstance(hook_inputs, tuple)
                    )

                    # num_outputs_cumsum (from the enclosing scope) holds the
                    # cumulative tensor count per layer, so this slice selects
                    # the chunk belonging to layer layer_idx.
                    if is_layer_tuple:
                        return scattered_inputs_dict[device][
                            num_outputs_cumsum[layer_idx]:num_outputs_cumsum[
                                layer_idx + 1]]

                    return scattered_inputs_dict[device][
                        num_outputs_cumsum[layer_idx]]

                hooks = []
                try:

                    # self.layer may be a single module or a list of modules.
                    layers = self.layer
                    if not isinstance(layers, list):
                        layers = [layers]

                    for layer_idx, layer in enumerate(layers):
                        # TODO:
                        # Allow a separate attribute_to_layer_input flag for
                        # each layer, i.e. attribute_to_layer_input[layer_idx]
                        # functools.partial binds this layer's index into the
                        # shared hook.
                        if attribute_to_layer_input:
                            hook = layer.register_forward_pre_hook(
                                functools.partial(layer_forward_hook,
                                                  layer_idx=layer_idx))
                        else:
                            hook = layer.register_forward_hook(
                                functools.partial(layer_forward_hook,
                                                  layer_idx=layer_idx))

                        hooks.append(hook)

                    # Forward with empty positional inputs; the hooks above
                    # inject the scattered tensors at each target layer.
                    output = _run_forward(self.forward_func, tuple(),
                                          target_ind, additional_forward_args)
                finally:
                    for hook in hooks:
                        hook.remove()

                assert output[0].numel() == 1, (
                    "Target not provided when necessary, cannot"
                    " take gradient with respect to multiple outputs.")
                # torch.unbind(output) yields a tuple of scalar tensors with
                # batch_size * #steps elements.
                grads = torch.autograd.grad(torch.unbind(output), inputs)
            return grads
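
Example #2 slices the scattered inputs with num_outputs_cumsum, which is defined in the enclosing scope and not shown here. A plausible construction, consistent with how the slicing is used above: count how many tensors each layer contributes and take a running sum with a leading zero. per_layer_inputs below is a hypothetical stand-in for those per-layer inputs:

    import torch
    from torch import Tensor

    # Hypothetical: one entry per layer, each a Tensor or a tuple of Tensors.
    per_layer_inputs = [torch.rand(2, 3), (torch.rand(2, 3), torch.rand(2, 4))]

    num_outputs = [1 if isinstance(inp, Tensor) else len(inp)
                   for inp in per_layer_inputs]
    # Prepend 0 so that num_outputs_cumsum[i]:num_outputs_cumsum[i + 1] slices
    # layer i's tensors out of the flattened, scattered input tuple.
    num_outputs_cumsum = torch.cumsum(torch.tensor([0] + num_outputs), dim=0)
    # -> tensor([0, 1, 3])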