Example #1
def make_initializers(args, neg_energy, normalizers, stats_funs):
  stats_vals = [stat_fun(arg) for stat_fun, arg in zip(stats_funs, args)]
  make_nat_init = (lambda i: lambda scale=1.:
                   make_vjp(neg_energy, i)(*stats_vals)[0](scale))
  natural_initializers = list(map(make_nat_init, range(len(normalizers))))
  make_mean_init = (lambda i, normalizer: lambda scale=1.:
                    grad(normalizer)(make_vjp(neg_energy, i)(*stats_vals)[0](scale)))
  mean_initializers = list(map(make_mean_init, range(len(normalizers)), normalizers))

  return natural_initializers, mean_initializers
Example #2
def rebar_all(params, est_params, noise_u, noise_v, f):
    # Returns objective, gradients, and gradients of variance of gradients.
    func_vals = f(bernoulli_sample(params, noise_u))
    var_vjp, grads = make_vjp(rebar, argnum=1)(params, est_params, noise_u,
                                               noise_v, f)
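    # Note: with cotangent 2 * grads / N, the VJP below yields the gradient of
    # mean(grads ** 2) with respect to est_params, i.e. the variance surrogate
    # that is minimized when tuning the estimator.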
    d_var_d_est = var_vjp(2 * grads / grads.shape[0])
    return func_vals, grads, d_var_d_est
Example #3
def flatten(value):
    """Flattens any nesting of tuples, lists, or dicts, with numpy arrays or
    scalars inside. Returns 1D numpy array and an unflatten function.
    Doesn't preserve mixed numeric types (e.g. floats and ints). Assumes dict
    keys are sortable."""
    unflatten, flat_value = make_vjp(_flatten)(value)
    return flat_value, unflatten
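
A minimal usage sketch, assuming the flatten above behaves like autograd.misc.flatten.flatten (the nested value below is illustrative):

import autograd.numpy as np

value = {'b': 0.5, 'w': np.ones((2, 2))}
flat, unflatten = flatten(value)   # flat is a 1-D array of length 5
restored = unflatten(flat)         # same nesting and shapes as `value`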
Example #4
def fixed_point_vjp(ans, f, a, x0, distance, tol):
    def rev_iter(params):
        a, x_star, x_star_bar = params
        vjp_x, _ = make_vjp(f(a))(x_star)
        vs = vspace(x_star)
        return lambda g: vs.add(vjp_x(g), x_star_bar)
    vjp_a, _ = make_vjp(lambda x, y: f(x)(y))(a, ans)
    return lambda g: vjp_a(fixed_point(rev_iter, tuple((a, ans, g)),
                           vspace(x0).zeros(), distance, tol))
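
For context, a minimal sketch of the forward fixed_point call this VJP supports, along the lines of autograd's fixed-point example (newton_sqrt_iter and distance here are illustrative choices, not part of the fragment):

import autograd.numpy as np
from autograd import grad
from autograd.misc.fixed_points import fixed_point

def newton_sqrt_iter(a):
    # One Newton step for x**2 = a; its fixed point is sqrt(a).
    return lambda x: 0.5 * (x + a / x)

def distance(x, y):
    return np.abs(x - y)

def sqrt(a):
    return fixed_point(newton_sqrt_iter, a, 1.0, distance, 1e-6)

print(grad(sqrt)(2.0))   # close to 1 / (2 * sqrt(2))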
Example #5
 def __init__(self, 
              input_variable: np.ndarray,
              predictions_fn: Callable[[np.ndarray], np.ndarray],
              loss_fn: Callable[[np.ndarray], float],
              damping_factor: float = 1.0, 
              damping_update_factor: float = 2/3,
              update_cond_threshold_low: float = 0.25, # following the least squares book
              update_cond_threshold_high: float = 0.75,
              damping_threshold_low = 1e-7,
              damping_threshold_high = 1e7,
              max_cg_iter: int = 100,
              cg_tol: float = 1e-5,
              squared_loss: bool = True) -> None:
     """
     Paramters:
     input_var: 
     1-d numpy array. Flatten before passing, if necessary.
     
     loss_input_fn: 
     Function that takes in input_var as the only input parameter.
     This should output a 1-d array, that is then passed to loss_fn for the 
     actual loss calculation.  
     Separating the loss calculation into a two step process this way 
     simplifies the second order calculations.
     
     loss_fn:
     Function that takes in the output of loss_input_fn and 
     calculates the singular loss value.
     """
     
     self._input_var = input_variable
     self._predictions_fn = predictions_fn
     self._loss_fn = loss_fn
     
      # Multiplicative factor to update the damping factor at the end of each cycle
     self._damping_factor = damping_factor
     self._damping_update_factor = damping_update_factor
     self._update_cond_threshold_low = update_cond_threshold_low
     self._update_cond_threshold_high =  update_cond_threshold_high
     self._damping_threshold_low = damping_threshold_low
     self._damping_threshold_high = damping_threshold_high
     self._max_cg_iter = max_cg_iter
     self._cg_tol = cg_tol
     self._squared_loss = squared_loss
     
     # variable used for the updates
     self._update_var = np.zeros_like(self._input_var)
     
     self._vjp = ag.make_vjp(self._predictions_fn)
     self._jvp = ag.differential_operators.make_jvp_reversemode(self._predictions_fn)
     
     self._grad = ag.grad(self._loss_fn)
     
     if self._squared_loss:
         self._hjvp = self._jvp
     else:
         self._hjvp = ag.differential_operators.make_hvp(self._loss_fn)
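
A minimal sketch of the two-step split described in the docstring, using an illustrative least-squares problem (the matrix A, target y_target, and function names are assumptions, not part of the fragment):

import autograd.numpy as np

A = np.random.randn(5, 3)
y_target = np.random.randn(5)

def predictions_fn(x):
    # First step: map the flat input variable to a 1-d prediction vector.
    return np.dot(A, x)

def loss_fn(preds):
    # Second step: reduce the predictions to a single scalar loss.
    return 0.5 * np.sum((preds - y_target) ** 2)

With this split, the optimizer only needs VJP/JVP machinery for predictions_fn and a gradient (or HVP) for loss_fn, which matches the make_vjp, make_jvp_reversemode, and make_hvp calls above.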
Example #6
def implicit_vjp(f, xstar, params):
    """Computes $\\bar{x}^\\top \\frac{dx^\\star}{d\\theta}$

    Args:
        f (callable): binary callable
        xstar (np.ndarray): fixed-point value
        params (tuple): parameters for the operator f

    Returns:
        np.ndarray: vector-Jacobian product at the fixed-point
    """
    vjp_xstar, _ = make_vjp(f, argnum=0)(xstar, params)
    vjp_params, _ = make_vjp(f, argnum=1)(xstar, params)

    def _vjp(xbar):
        ybar = basic_iterative_solver(vjp_xstar, xbar)
        return vjp_params(ybar)

    return _vjp
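
basic_iterative_solver is not shown in this fragment; a hypothetical sketch of what it could do here is a Neumann-series (fixed-point) solve of y = xbar + J^T y, where J^T v is vjp_xstar(v):

def basic_iterative_solver(vjp_xstar, xbar, n_iter=100):
    # Hypothetical: iterate y <- xbar + J^T y, i.e. accumulate the Neumann series
    # for (I - J^T)^{-1} xbar; the iteration count is illustrative only.
    y = xbar
    for _ in range(n_iter):
        y = xbar + vjp_xstar(y)
    return y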
Example #7
def unflatten_tracing():
    val = [
        npr.randn(4), [npr.randn(3, 4), 2.5], (), (2.0, [1.0,
                                                         npr.randn(2)])
    ]
    vect, unflatten = flatten(val)

    def f(vect):
        return unflatten(vect)

    flatten2, _ = make_vjp(f)(vect)
    assert np.all(vect == flatten2(val))
Example #8
def view_update(data, view_fun):
    view_vjp, item = make_vjp(view_fun)(data)
    item_vs = vspace(item)

    def update(new_item):
        assert item_vs == vspace(new_item), \
            "Please ensure new_item shape and dtype match the data view."
        diff = view_vjp(
            item_vs.add(new_item, item_vs.scalar_mul(item, -np.uint64(1))))
        return vspace(data).add(data, diff)

    return item, update
Example #9
def vjp(func, x, backend='autograd'):
    """
    Returns a constructor that would generate a function that computes the VJP between its argument and the
    Jacobian of func.
    :param func: Function handle of loss function.
    :param x: List. A list of all arguments to func. The order of arguments must match.
    :return: The returned constructor receives the input of the differentiated function as input, and the function it returns
             receives the (adjoint) vector as input.
    """
    if backend == 'autograd':
        return ag.make_vjp(func, x)
    elif backend == 'pytorch':
        raise NotImplementedError('VJP for Pytorch backend is not implemented yet.')
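
A minimal usage sketch under the 'autograd' backend, assuming the vjp wrapper above is in scope and that its second argument is the index of the argument to differentiate (the loss function below is illustrative):

import autograd.numpy as np

loss = lambda w: np.sum(w ** 2)
vjp_constructor = vjp(loss, 0, backend='autograd')
vjp_fn, value = vjp_constructor(np.ones(3))
print(vjp_fn(1.0))   # -> array([2., 2., 2.])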
Example #10
 def __init__(self,
              pot_energy,
              chol_metric,
              grad_pot_energy=None,
              vjp_chol_metric=None):
     super().__init__(pot_energy, grad_pot_energy)
     self._chol_metric = chol_metric
     if vjp_chol_metric is None and autograd_available:
         self._vjp_chol_metric = make_vjp(chol_metric)
     elif vjp_chol_metric is None and not autograd_available:
         raise ValueError('Autograd not available therefore '
                          'vjp_chol_metric must be provided.')
     else:
         self._vjp_chol_metric = vjp_chol_metric
Example #11
def multilinear_representation(log_joint_fun, args, supports):
    expr = _split_einsum_stats(canonicalize(make_expr(log_joint_fun, *args)))

    stats_nodes, keys = [], []
    for name, free_node in expr.free_vars.items():
        nodes = find_sufficient_statistic_nodes(expr, name)
        names = (_summarize_node(node, free_node) for node in nodes)
        key, nodes = zip(*sorted(zip(names, nodes)))
        stats_nodes.append(nodes)
        keys.append(key)

    neg_energy = _make_neg_energy(expr, stats_nodes, supports)

    make_normalizer_fun = (lambda key, support: lambda arg:
                           suff_stat_to_log_normalizer[support][key](*arg))
    normalizers = list(map(make_normalizer_fun, keys, supports))

    make_stat_fun = (lambda name, nodes: lambda arg: tuple(
        eval_node(node, expr.free_vars, {name: arg}) for node in nodes))
    stats_funs = list(map(make_stat_fun, expr.free_vars, stats_nodes))

    stats_vals = [stat_fun(arg) for stat_fun, arg in zip(stats_funs, args)]
    make_nat_init = (lambda i: lambda scale=1.: make_vjp(neg_energy, i)
                     (*stats_vals)[0](scale))
    natural_initializers = list(map(make_nat_init, range(len(normalizers))))
    make_mean_init = (lambda i, normalizer: lambda scale=1.: grad(normalizer)
                      (make_vjp(neg_energy, i)(*stats_vals)[0](scale)))
    mean_initializers = list(map(make_mean_init, range(len(normalizers)), normalizers))

    samplers = [
        suff_stat_to_dist[support][key]
        for key, support in zip(keys, supports)
    ]

    return (neg_energy, normalizers, stats_funs, natural_initializers,
            mean_initializers, samplers)
Example #12
def unflatten_tracing():
    val = [npr.randn(4), [npr.randn(3,4), 2.5], (), (2.0, [1.0, npr.randn(2)])]
    vect, unflatten = flatten(val)
    def f(vect): return unflatten(vect)
    flatten2, _ = make_vjp(f)(vect)
    assert np.all(vect == flatten2(val))
Example #13
 def augmented_dynamics(augmented_state, t, flat_args):
     # Original system augmented with vjp_y, vjp_t and vjp_args.
     y, vjp_y, _, _ = unpack(augmented_state)
     vjp_all, dy_dt = make_vjp(flat_func, argnum=(0, 1, 2))(y, t, flat_args)
     vjp_y, vjp_t, vjp_args = vjp_all(-vjp_y)
     return np.hstack((dy_dt, vjp_y, vjp_t, vjp_args))
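
unpack is not shown in this fragment; below is a hypothetical sketch of the packing convention implied by the np.hstack order above (in the original code it presumably closes over the state dimension, so it takes only the augmented state):

def unpack(augmented_state, n_y):
    # Hypothetical layout: [y, vjp_y, vjp_t, vjp_args] concatenated into one flat vector.
    y = augmented_state[:n_y]
    vjp_y = augmented_state[n_y:2 * n_y]
    vjp_t = augmented_state[2 * n_y:2 * n_y + 1]
    vjp_args = augmented_state[2 * n_y + 1:]
    return y, vjp_y, vjp_t, vjp_args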
Example #14
from autograd import make_vjp

import autograd.numpy as np
import autograd.numpy.random as npr

dot_0 = lambda a, b, g: make_vjp(np.dot, argnum=0)(a, b)[0](g)
dot_1 = lambda a, b, g: make_vjp(np.dot, argnum=1)(a, b)[0](g)

dot_0_0 = lambda a, b, g: make_vjp(dot_0, argnum=0)(a, b, g)[0](a)
dot_0_1 = lambda a, b, g: make_vjp(dot_0, argnum=1)(a, b, g)[0](a)
dot_0_2 = lambda a, b, g: make_vjp(dot_0, argnum=2)(a, b, g)[0](a)

dot_1_0 = lambda a, b, g: make_vjp(dot_1, argnum=0)(a, b, g)[0](b)
dot_1_1 = lambda a, b, g: make_vjp(dot_1, argnum=1)(a, b, g)[0](b)
dot_1_2 = lambda a, b, g: make_vjp(dot_1, argnum=2)(a, b, g)[0](b)

a = npr.randn(2, 3, 4, 5)
b = npr.randn(2, 3, 5, 4)
g = npr.randn(2, 3, 4, 2, 3, 4)

def time_dot_0():
    dot_0(a, b, g)

def time_dot_1():
    dot_1(a, b, g)

def time_dot_0_0():
    dot_0_0(a, b, g)

def time_dot_0_1():
    dot_0_1(a, b, g)
Example #15
from autograd import make_vjp

import autograd.numpy as np
import autograd.numpy.random as npr

dot_0 = lambda a, b, g: make_vjp(np.dot, argnum=0)(a, b)[0](g)
dot_1 = lambda a, b, g: make_vjp(np.dot, argnum=1)(a, b)[0](g)

dot_0_0 = lambda a, b, g: make_vjp(dot_0, argnum=0)(a, b, g)[0](a)
dot_0_1 = lambda a, b, g: make_vjp(dot_0, argnum=1)(a, b, g)[0](a)
dot_0_2 = lambda a, b, g: make_vjp(dot_0, argnum=2)(a, b, g)[0](a)

dot_1_0 = lambda a, b, g: make_vjp(dot_1, argnum=0)(a, b, g)[0](b)
dot_1_1 = lambda a, b, g: make_vjp(dot_1, argnum=1)(a, b, g)[0](b)
dot_1_2 = lambda a, b, g: make_vjp(dot_1, argnum=2)(a, b, g)[0](b)

a = npr.randn(2, 3, 4, 5)
b = npr.randn(2, 3, 5, 4)
g = npr.randn(2, 3, 4, 2, 3, 4)


def time_dot_0():
    dot_0(a, b, g)


def time_dot_1():
    dot_1(a, b, g)


def time_dot_0_0():
    dot_0_0(a, b, g)
Example #16
	def augmented_dynamics(augmented_state, t, flat_args):
		# Original system augmented with vjp_y, vjp_t and vjp_args
		y, vjp_y, _, _ = unpack(augmented_state)
		vjp_all, dy_dt = make_vjp(flat_func, argnum=(0, 1, 2))(y, t, flat_args)
		vjp_y, vjp_t, vjp_args = vjp_all(-vjp_y)
		return np.hstack((dy_dt, vjp_y, vjp_t, vjp_args))
Example #17
    def __init__(self,
                 input_variable: np.ndarray,
                 predictions_fn: Callable[[np.ndarray], np.ndarray],
                 loss_fn: Callable[[np.ndarray], float],
                 damping_factor: float = 1.0,
                 damping_update_factor: float = 0.999,
                 damping_update_frequency: int = 5,
                 update_cond_threshold_low: float = 0.5,
                 update_cond_threshold_high: float = 1.5,
                 damping_threshold_low: float = 1e-7,
                 damping_threshold_high: float = 1e7,
                 alpha_init: float = 1.0,
                 squared_loss: bool = True) -> None:
        """
        Paramters:
        input_var: 
        1-d numpy array. Flatten before passing, if necessary.
        
        loss_input_fn: 
        Function that takes in input_var as the only input parameter.
        This should output a 1-d array, that is then passed to loss_fn for the 
        actual loss calculation.  
        Separating the loss calculation into a two step process this way 
        simplifies the second order calculations.
        
        loss_fn:
        Function that takes in the output of loss_input_fn and 
        calculates the singular loss value.
        
        from 
        The alpha should generally just be 1.0 and doesn't change. 
        The beta and rho values are updated at each cycle, so there is no intial value."""

        self._input_var = input_variable
        self._predictions_fn = predictions_fn
        self._loss_fn = loss_fn

        # Multiplicative factor to update the damping factor at the end of each cycle
        self._damping_factor = damping_factor
        self._damping_update_factor = damping_update_factor
        self._damping_update_frequency = damping_update_frequency
        self._update_cond_threshold_low = update_cond_threshold_low
        self._update_cond_threshold_high = update_cond_threshold_high
        self._damping_threshold_low = damping_threshold_low
        self._damping_threshold_high = damping_threshold_high
        self._squared_loss = squared_loss
        self._alpha = alpha_init

        self._z = np.zeros_like(self._input_var)
        self._iteration = 0

        self._vjp = ag.make_vjp(self._predictions_fn)
        self._jvp = ag.differential_operators.make_jvp_reversemode(
            self._predictions_fn)

        self._grad = ag.grad(self._loss_fn)

        if self._squared_loss:
            self._hjvp = self._jvp
        else:
            self._hjvp = ag.differential_operators.make_hvp(self._loss_fn)

        self._iteration = 0
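
A hypothetical sketch of how the damping thresholds above are typically used in a Levenberg-Marquardt-style update; this is illustrative only and not the fragment's actual update method:

import numpy as np

def _update_damping(self, reduction_ratio):
    # Shrink damping when the quadratic model predicts the loss well,
    # grow it when it does not, then clip to the configured bounds.
    if reduction_ratio > self._update_cond_threshold_high:
        self._damping_factor *= self._damping_update_factor
    elif reduction_ratio < self._update_cond_threshold_low:
        self._damping_factor /= self._damping_update_factor
    self._damping_factor = np.clip(self._damping_factor,
                                   self._damping_threshold_low,
                                   self._damping_threshold_high)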
Example #18
 def rev_iter(params):
     a, x_star, x_star_bar = params
     vjp_x, _ = make_vjp(f(a))(x_star)
     vs = vspace(x_star)
     return lambda g: vs.add(vjp_x(g), x_star_bar)
Example #19
 def grad_h(x):
   f_vjp, _ = make_vjp(f)(getval(x))
   return f_vjp(grad(g)(f(x)))
Example #20
 def ggnvp_maker(x):
   return make_vjp(grad_h)(x)[0]
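
A self-contained sketch combining the two fragments above into a Gauss-Newton-vector product; f and g are illustrative choices, and getval is assumed to be importable from autograd.tracer:

import autograd.numpy as np
from autograd import grad, make_vjp
from autograd.tracer import getval

f = lambda x: np.tanh(x)        # illustrative "model"
g = lambda y: np.sum(y ** 2)    # illustrative loss on the model output

def grad_h(x):
    # Linearize f at the current value of x (getval stops differentiation
    # through this factor), then pull back the loss gradient: J_f(x)^T grad g(f(x)).
    f_vjp, _ = make_vjp(f)(getval(x))
    return f_vjp(grad(g)(f(x)))

def ggnvp_maker(x):
    # The Jacobian of grad_h is the Gauss-Newton matrix J_f^T H_g J_f,
    # so its VJP gives Gauss-Newton-vector products.
    return make_vjp(grad_h)(x)[0]

ggnvp = ggnvp_maker(np.ones(3))
print(ggnvp(np.array([1.0, 0.0, 0.0])))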
Example #21
 def improve_approx(g, k):
   return lambda x, v: make_vjp(g)(x, v)[0](v) + f(x) / factorial(k)