def grad(func, argnum=None):
    """Returns the gradient as a callable function of (functions of) QNodes.

    Thin wrapper around :mod:`autograd.grad`. When ``argnum`` is omitted,
    which positional arguments get differentiated is decided each time the
    returned function is called, by reading each argument's
    ``requires_grad`` attribute. Arguments lacking that attribute are
    treated as differentiable.

    Args:
        func (function): a Python function or QNode that contains
            a combination of quantum and classical nodes

    Keyword Args:
        argnum (int, list(int), None): Which argument(s) to take the gradient
            with respect to. If left as ``None``, differentiability is
            inferred from the ``requires_grad`` property of the arguments.
            Supplying a value overrides that inference and fixes the
            differentiated arguments for the returned gradient function.

    Returns:
        function: The function that returns the gradient of the input
        function with respect to the differentiable arguments, or, if
        specified, the arguments in ``argnum``.
    """
    # pylint: disable=no-value-for-parameter
    if argnum is None:

        def _gradient_function(*args, **kwargs):
            """Differentiate ``func`` at ``args`` w.r.t. its trainable arguments.

            The trainable positions are recomputed on every call (rather than
            once, when ``grad`` was called), so a gradient function created
            once keeps working if the arguments' ``requires_grad`` flags
            change between calls.
            """
            trainable = [
                position
                for position, value in enumerate(args)
                if getattr(value, "requires_grad", True)
            ]
            return _grad(func, trainable)(*args, **kwargs)

        return _gradient_function

    # An explicit argnum is passed straight through, preserving the behaviour
    # of existing code that specified it manually.
    return _grad(func, argnum)
def _gradient_function(*args, **kwargs):
    """Evaluate the autograd gradient of ``func`` for the given arguments.

    Every positional argument whose ``requires_grad`` attribute is truthy
    is differentiated, as is any argument without that attribute. The
    argnum list is rebuilt on each call, so changes in the arguments'
    differentiability between calls are respected.

    NOTE(review): ``func`` is not a parameter here. It is resolved as a
    free name from the enclosing (module) scope, and no definition of it
    is visible in this file section. Confirm it exists; otherwise calling
    this raises ``NameError``.
    """
    argnums = [i for i, a in enumerate(args) if getattr(a, "requires_grad", True)]
    gradient_fn = _grad(func, argnums)
    return gradient_fn(*args, **kwargs)
def grad(func, argnum=None):
    """Returns the gradient as a callable function of (functions of) QNodes.

    This is a wrapper around the :mod:`autograd.grad` function.

    Args:
        func (function): a Python function or QNode that contains
            a combination of quantum and classical nodes
        argnum (int, list(int), None): which argument(s) to take the gradient
            with respect to. If ``None`` (the default), the differentiable
            arguments are determined each time the returned function is
            called. Every positional argument whose ``requires_grad``
            attribute is truthy is included, as is any argument without
            that attribute.

    Returns:
        function: the function that returns the gradient of the input function
        with respect to the arguments in ``argnum``, or, if ``argnum`` is
        ``None``, with respect to the differentiable arguments
    """
    # pylint: disable=no-value-for-parameter
    if argnum is not None:
        # Explicit argnum: identical to the behaviour before argnum was optional.
        return _grad(func, argnum)

    def _gradient_function(*args, **kwargs):
        """Differentiate ``func`` w.r.t. the currently trainable positional args."""
        # Recomputed on every call so a gradient function created once stays
        # correct if the arguments' ``requires_grad`` flags change later.
        trainable = [
            idx for idx, arg in enumerate(args) if getattr(arg, "requires_grad", True)
        ]
        return _grad(func, trainable)(*args, **kwargs)

    return _gradient_function