コード例 #1
0
ファイル: custom.py プロジェクト: wesselb/lab
def autograd_register(f, s_f):
    """Wrap `f` as an AutoGrad primitive whose VJP is supplied by `s_f`.

    Args:
        f (function): Function to register.
        s_f (function): Sensitivity of `f`. It is called as
            ``s_f(s_y, y, *args, **kw_args)`` and must yield one sensitivity
            per positional argument of `f` (a single value is accepted too,
            since the result is normalised with ``as_tuple``).

    Returns:
        function: AutoGrad primitive.
    """
    wrapped = primitive(f)

    def make_vjp(argnums, ans, call_args, call_kwargs):
        # AutoGrad asks only for the argument positions listed in `argnums`.
        def vjp(cotangent):
            every_sens = as_tuple(s_f(cotangent, ans, *call_args, **call_kwargs))
            return tuple(every_sens[pos] for pos in argnums)

        return vjp

    defvjp_argnums(wrapped, make_vjp)
    return wrapped
コード例 #2
0
ファイル: linalg.py プロジェクト: abfarr/moo2020
    a, b, q = args
    if 0 in argnums:
        da = dms[0]
        db = dms[1] if 1 in argnums else 0
    else:
        da = 0
        db = dms[0] if 1 in argnums else 0
    dq = dms[-1] if 2 in argnums else 0
    rhs = dq - anp.dot(da, ans) - anp.dot(ans, db)
    return solve_sylvester(a, b, rhs)


# Register `_jvp_sylvester` as the tangent (JVP) rule for `solve_sylvester`.
defjvp_argnums(solve_sylvester, _jvp_sylvester)


def _vjp_sylvester(argnums, ans, args, _):
    """Build the VJP of ``solve_sylvester`` for the arguments in ``argnums``.

    With ``X = ans`` solving ``a @ X + X @ b = q``, the output cotangent ``g``
    is pulled back through the adjoint equation
    ``a.T @ lam + lam @ b.T = g``. The VJPs are ``-lam @ X.T`` for ``a``,
    ``-X.T @ lam`` for ``b`` and ``lam`` for ``q``. Only the requested ones
    are returned, as a tuple ordered by argument position.
    """
    a, b, _q = args

    def vjp(g):
        lam = solve_sylvester(anp.transpose(a), anp.transpose(b), g)
        # Thunks so each VJP is only computed when its argument is requested.
        per_argument = {
            0: lambda: -anp.dot(lam, anp.transpose(ans)),
            1: lambda: -anp.dot(anp.transpose(ans), lam),
            2: lambda: lam,
        }
        return tuple(
            per_argument[pos]() for pos in (0, 1, 2) if pos in argnums
        )

    return vjp


# Register `_vjp_sylvester` as the reverse-mode (VJP) rule for `solve_sylvester`.
defvjp_argnums(solve_sylvester, _vjp_sylvester)
コード例 #3
0
ファイル: integrate.py プロジェクト: HIPS/autograd
            # Run augmented system backwards to the previous observation.
            aug_y0 = np.hstack((yt[i, :], vjp_y, vjp_t0, vjp_args))
            aug_ans = odeint(augmented_dynamics, aug_y0,
                             np.array([t[i], t[i - 1]]), tuple((flat_args,)), **kwargs)
            _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1])

            # Add gradient from current output.
            vjp_y = vjp_y + g[i - 1, :]

        time_vjp_list.append(vjp_t0)
        vjp_times = np.hstack(time_vjp_list)[::-1]

        return None, vjp_y, vjp_times, unflatten(vjp_args)
    return vjp_all


def argnums_unpack(all_vjp_builder):
    """Adapt an all-arguments VJP builder to the ``argnums`` VJP-maker form.

    ``all_vjp_builder(ans, *args, **kwargs)`` must return a callable that maps
    an output cotangent ``g`` to an indexable collection with one VJP per
    positional argument. The returned maker takes
    ``(argnums, ans, combined_args, kwargs)``. Its VJP evaluates every
    argument's VJP once and returns a list with only the entries at the
    positions in ``argnums``, in ``argnums`` order.
    """
    def build_selected_vjps(argnums, ans, combined_args, kwargs):
        vjp_of_everything = all_vjp_builder(ans, *combined_args, **kwargs)

        def chosen_vjps(g):
            every_vjp = vjp_of_everything(g)
            return [every_vjp[position] for position in argnums]

        return chosen_vjps

    return build_selected_vjps

# Register the VJP for `odeint`: `argnums_unpack` wraps `grad_odeint`, which
# presumably builds VJPs for every argument (not fully visible here), so that
# autograd receives only the VJPs it requests.
defvjp_argnums(odeint, argnums_unpack(grad_odeint))
コード例 #4
0
ファイル: linalg.py プロジェクト: HIPS/autograd
def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64):
    """Tangent of ``sqrtm`` at ``A`` along ``dA``.

    If ``X = ans`` satisfies ``X @ X = A``, differentiating gives
    ``X @ dX + dX @ X = dA``. This is a Sylvester equation in ``dX`` whose two
    coefficient matrices are both ``X``. ``A`` and ``blocksize`` are accepted
    to mirror the ``sqrtm`` signature but are not used.
    """
    assert disp, "sqrtm jvp not implemented for disp=False"
    root = ans
    return solve_sylvester(root, root, dA)
# Register `_jvp_sqrtm` as the tangent (JVP) rule for `sqrtm`.
defjvp(sqrtm, _jvp_sqrtm)

def _jvp_sylvester(argnums, dms, ans, args, _):
    """JVP of ``solve_sylvester`` with respect to the arguments in ``argnums``.

    ``dms`` holds one tangent per entry of ``argnums``, ordered by argument
    position. Differentiating ``a @ X + X @ b = q`` at ``X = ans`` gives
    ``a @ dX + dX @ b = dq - da @ ans - ans @ db``. That equation is solved
    again with ``solve_sylvester``. Tangents of arguments not in ``argnums``
    are taken as 0.
    """
    a, b, _q = args
    cursor = 0
    if 0 in argnums:
        da = dms[cursor]
        cursor += 1
    else:
        da = 0
    db = dms[cursor] if 1 in argnums else 0
    # Index 2 is the largest argnum, so q's tangent (when requested) is last.
    dq = dms[-1] if 2 in argnums else 0
    rhs = dq - anp.dot(da, ans) - anp.dot(ans, db)
    return solve_sylvester(a, b, rhs)
# Register `_jvp_sylvester` as the tangent (JVP) rule for `solve_sylvester`.
defjvp_argnums(solve_sylvester, _jvp_sylvester)

def _vjp_sylvester(argnums, ans, args, _):
    """Build the VJP of ``solve_sylvester`` for the arguments in ``argnums``.

    ``X = ans`` solves ``a @ X + X @ b = q``. The cotangent ``g`` is mapped
    through the adjoint equation ``a.T @ lam + lam @ b.T = g``. The VJPs are
    then ``-lam @ X.T`` (for ``a``), ``-X.T @ lam`` (for ``b``) and ``lam``
    (for ``q``), returned as a tuple in argument order holding only the
    requested entries.
    """
    a, b, _unused_q = args

    def vjp(g):
        lam = solve_sylvester(anp.transpose(a), anp.transpose(b), g)
        grads = []
        for position in range(3):
            if position not in argnums:
                continue
            if position == 0:
                grads.append(-anp.dot(lam, anp.transpose(ans)))
            elif position == 1:
                grads.append(-anp.dot(anp.transpose(ans), lam))
            else:
                grads.append(lam)
        return tuple(grads)

    return vjp
# Register `_vjp_sylvester` as the reverse-mode (VJP) rule for `solve_sylvester`.
defvjp_argnums(solve_sylvester, _vjp_sylvester)
			vjp_y = vjp_y + g[i - 1, :]

		time_vjp_list.append(vjp_t0)
		vjp_times = np.hstack(time_vjp_list)[::-1]

		return None, vjp_y, vjp_times, unflatten(vjp_args)

	return vjp_all


def grad_argnums_wrapper(all_vjp_builder):
	"""
	A generic autograd helper function. Takes a function that builds vjps for
	all arguments and wraps it so that only the requested vjps are returned.

	`all_vjp_builder(ans, *args, **kwargs)` must return a callable that maps an
	output cotangent `g` to an indexable sequence with one vjp per argument.
	The returned maker takes `(argnums, ans, combined_args, kwargs)`. Its vjp
	returns a list of the vjps at the positions in `argnums`, in that order.
	"""
	def build_selected_vjps(argnums, ans, combined_args, kwargs):
		vjp_of_all = all_vjp_builder(ans, *combined_args, **kwargs)

		def chosen_vjps(g):
			# Every vjp is computed once; keep only the requested positions.
			everything = vjp_of_all(g)
			return [everything[index] for index in argnums]

		return chosen_vjps

	return build_selected_vjps


if __name__ == '__main__':
	# Register the ODE vjp, wrapping the all-arguments builder so autograd
	# receives only the vjps it requests. `defvjp_argnums` is called for its
	# registration side effect; it returns None, so there is nothing to print.
	# NOTE(review): registration runs only when this file is executed as a
	# script, so importing the module leaves `odeint` without a vjp. Confirm
	# whether it should move to module level.
	defvjp_argnums(odeint, grad_argnums_wrapper(grad_odeint_all))
コード例 #6
0
ファイル: tracers.py プロジェクト: muskanmahajan37/autoconj

# VJP of logdet: the gradient of log det(x) is inv(x)^T, scaled by the
# cotangent g.
# NOTE(review): `_add2d` presumably appends axes so g broadcasts over the
# trailing matrix dimensions, and `_T` presumably swaps the last two axes.
# Confirm against their definitions.
defvjp(logdet, lambda ans, x: lambda g: _add2d(g) * _T(np.linalg.inv(x)))


@primitive
def add_n(*args):
    """Return the elementwise sum of all positional arguments.

    The arguments are folded left to right with ``np.add``. Calling with no
    arguments raises ``TypeError`` (``reduce`` of an empty sequence).
    """
    total = reduce(np.add, args)
    return total


def grad_add_n_full(parent_argnums, ans, args, kwargs):
    """VJP maker for ``add_n`` over the arguments in ``parent_argnums``.

    The derivative of a sum with respect to each addend is the identity, so
    every requested VJP is the cotangent ``g`` itself, passed through
    ``unbroadcast`` with that addend's metadata. ``unbroadcast`` presumably
    sums ``g`` back down to the addend's shape; confirm against its
    definition. ``ans`` and ``kwargs`` are unused.
    """
    # Capture metadata eagerly so the VJP does not need to hold the inputs.
    addend_meta = [np.metadata(args[i]) for i in parent_argnums]

    def vjp(g):
        return [unbroadcast(g, meta) for meta in addend_meta]

    return vjp


# Register `grad_add_n_full` as the VJP maker for `add_n`.
defvjp_argnums(add_n, grad_add_n_full)

## debugging


def print_expr(expr, env={}):
    """Return a string with an SSA-like representation of an expression."""
    if isinstance(expr, ConstExpr):
        return str(expr.val)
    elif isinstance(expr, GraphExpr):
        fragment = []
        temp_names = ('temp_{}'.format(i) for i in itertools.count())
        apply_str = '{} = {}({})\n'.format

        def eval_args(node, partial_args):
            args = subvals(node.args, zip(node.parent_argnums, partial_args))
コード例 #7
0
                             **kwargs)
            _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1])

            # Add gradient from current output.
            vjp_y = vjp_y + g[i - 1, :]

        time_vjp_list.append(vjp_t0)
        vjp_times = np.hstack(time_vjp_list)[::-1]

        return None, vjp_y, vjp_times, unflatten(vjp_args)

    return vjp_all


def argnums_unpack(all_vjp_builder):
    """Turn an all-arguments VJP builder into an ``argnums``-aware VJP maker.

    ``all_vjp_builder(ans, *args, **kwargs)`` must return a callable mapping
    the output cotangent ``g`` to an indexable collection with one VJP per
    positional argument. The maker returned here has the signature
    ``(argnums, ans, combined_args, kwargs)``. Its VJP yields a list holding
    only the entries at the positions in ``argnums``, in ``argnums`` order.
    """
    def build_selected_vjps(argnums, ans, combined_args, kwargs):
        full_vjp = all_vjp_builder(ans, *combined_args, **kwargs)

        def chosen_vjps(g):
            # All VJPs are evaluated together, then filtered by position.
            all_results = full_vjp(g)
            return [all_results[slot] for slot in argnums]

        return chosen_vjps

    return build_selected_vjps


# Register the VJP for `odeint`: `argnums_unpack` adapts `grad_odeint`
# (presumably an all-arguments VJP builder; not fully visible here) so that
# autograd receives only the VJPs it requests.
defvjp_argnums(odeint, argnums_unpack(grad_odeint))