def test_grad_closure_with_vmap(self):
    # Regression test for https://github.com/google/jax/issues/2718:
    # vmapping value_and_grad of a jitted function whose ODE dynamics close
    # over the differentiated argument.
    sample_idx = 11
    fd_step = 1e-5

    @jax.jit
    def experiment(x):
        # `x` is captured by the dynamics closure handed to the ODE solver.
        def model(y, t):
            return -x * y

        trajectory = odeint(model, 1., np.arange(0, 10, 0.1))
        return trajectory[-1]

    rates = np.arange(0., 1., 0.01)
    values, grads = jax.vmap(jax.value_and_grad(experiment))(rates)  # doesn't crash

    # Reference: direct evaluation plus a forward finite-difference slope at
    # one sample of the batch.
    rate = rates[sample_idx]
    expected_value = experiment(rate)
    expected_grad = (experiment(rate + fd_step) - expected_value) / fd_step

    self.assertAllClose(
        (values[sample_idx], grads[sample_idx]),
        (expected_value, expected_grad),
        check_dtypes=False, atol=1e-2, rtol=1e-2)
def train(network_def, target_params, optimizer, states, actions,
          next_states, rewards, terminals, cumulative_gamma):
    """Run one training step and return the updated optimizer and the loss.

    The bootstrap target is computed with ``target_params`` through
    ``target_q``; the loss is the mean ``huber_loss`` between that target and
    the online Q-value of the action taken in each transition.

    Returns:
      A tuple ``(optimizer, loss)`` where ``optimizer`` is the result of
      ``optimizer.apply_gradient`` on the loss gradient.
    """
    online_params = optimizer.target

    def target_network(state):
        return network_def.apply(target_params, state)

    def loss_fn(params, target):
        def online_network(state):
            return network_def.apply(params, state)

        batch_q = jnp.squeeze(jax.vmap(online_network)(states).q_values)
        # Select, per example, the Q-value of the action that was taken.
        chosen_q = jax.vmap(lambda q_row, action: q_row[action])(
            batch_q, actions)
        return jnp.mean(jax.vmap(huber_loss)(target, chosen_q))

    bootstrap_target = target_q(target_network, next_states, rewards,
                                terminals, cumulative_gamma)
    loss, grad = jax.value_and_grad(loss_fn)(online_params, bootstrap_target)
    return optimizer.apply_gradient(grad), loss
def _wrap_f(vs, names, f, jit, _convert):
    """Wrap an objective over a variable container as a flat-vector function.

    Args:
        vs: Variable container; it is copied, and only the copy is written to.
        names: Names of the variables packed into the latent vector.
        f: Objective taking the variable container.
        jit: If truthy, compile the value-and-gradient function with `B.jit`.
        _convert: Callable applied to `(value, gradient)` before returning.

    Returns:
        A tuple `(f_evals, f_wrapped)`: the list that collects every
        successfully evaluated objective value, and the wrapped function that
        maps a latent vector to `(value, gradient)`.
    """
    # Differentiable assignments overwrite the variables, so operate on a copy.
    scratch_vs = vs.copy()
    # Objective values recorded across successful evaluations.
    f_evals = []

    def f_vectorised(x):
        scratch_vs.set_latent_vector(x, *names, differentiable=True)
        return f(scratch_vs)

    objective_and_grad = value_and_grad(f_vectorised)
    if jit:
        objective_and_grad = B.jit(objective_and_grad)

    def f_wrapped(x):
        x = B.cast(vs.dtype, x)

        # Any failure while evaluating is delegated to `exception`.
        try:
            obj_value, grad = objective_and_grad(x)
        except Exception as e:
            return exception(x, e)

        obj_value, grad = _convert(obj_value, grad)

        # The gradient may have a memory layout that cannot be adjusted in
        # place; a fresh copy can always be manipulated freely.
        grad = np.array(grad)

        f_evals.append(obj_value)
        return obj_value, grad

    return f_evals, f_wrapped
def train_step(model: models.NerfModel,
               rng_key: Callable[[int], jnp.ndarray],
               state,
               batch: Dict[str, Any],
               scalar_params: ScalarParams,
               use_elastic_loss: bool = False,
               elastic_reduce_method: str = 'median',
               use_background_loss: bool = False):
    """One optimization step.

    Args:
      model: the model module to evaluate.
      rng_key: The random number generator key; it is split four ways below.
      state: model_utils.TrainState, state of model and optimizer.
      batch: dict. A mini-batch of data for training.
      scalar_params: scalar-valued parameters.
      use_elastic_loss: if True use the elastic regularization loss (applied
        to the 'coarse' output only).
      elastic_reduce_method: which method to use to reduce the samples for the
        elastic loss. 'median' selects the median depth point sample while
        'weight' computes a weighted sum using the density weights.
      use_background_loss: if True use the background regularization loss.

    Returns:
      A tuple ``(new_state, stats, rng_key)``:
        new_state: ``state`` with its optimizer replaced by the updated one.
        stats: dict with a nested stats dict under 'fine' and/or 'coarse'
          (whichever the model output contains) and 'background_loss' when
          the background loss is enabled; averaged over the 'batch' axis.
        rng_key: the first key produced by the split, for the next step.
    """
    rng_key, fine_key, coarse_key, reg_key = random.split(rng_key, 4)

    # pylint: disable=unused-argument
    def _compute_loss_and_stats(params, model_out, use_elastic_loss=False):
        # Photometric loss on the first three channels of the target.
        rgb_loss = ((model_out['rgb'] - batch['rgb'][..., :3])**2).mean()
        stats = {
            'loss/rgb': rgb_loss,
        }
        loss = rgb_loss
        if use_elastic_loss:
            v_elastic_fn = jax.jit(vmap(vmap(compute_elastic_loss)))
            # Weights only select/scale samples; no gradient flows through them.
            weights = lax.stop_gradient(model_out['weights'])
            jacobian = model_out['warp_jacobian']
            # Pick the median point Jacobian.
            if elastic_reduce_method == 'median':
                depth_indices = model_utils.compute_depth_index(weights)
                jacobian = jnp.take_along_axis(
                    # Unsqueeze axes: sample axis, Jacobian row, Jacobian col.
                    jacobian, depth_indices[..., None, None, None], axis=-3)
            # Compute loss using Jacobian.
            elastic_loss, elastic_residual = v_elastic_fn(jacobian)
            # Multiply weight if weighting by density.
            if elastic_reduce_method == 'weight':
                elastic_loss = weights * elastic_loss
            elastic_loss = elastic_loss.sum(axis=-1).mean()
            stats['loss/elastic'] = elastic_loss
            stats['residual/elastic'] = jnp.mean(elastic_residual)
            loss += scalar_params.elastic_loss_weight * elastic_loss

        # Diagnostics of the warp field; these do not contribute to the loss.
        if 'warp_jacobian' in model_out:
            jacobian = model_out['warp_jacobian']
            jacobian_det = jnp.linalg.det(jacobian)
            jacobian_div = utils.jacobian_to_div(jacobian)
            jacobian_curl = utils.jacobian_to_curl(jacobian)
            stats['metric/jacobian_det'] = jnp.mean(jacobian_det)
            stats['metric/jacobian_div'] = jnp.mean(jacobian_div)
            stats['metric/jacobian_curl'] = jnp.mean(
                jnp.linalg.norm(jacobian_curl, axis=-1))

        stats['loss/total'] = loss
        stats['metric/psnr'] = utils.compute_psnr(rgb_loss)
        return loss, stats

    def _loss_fn(params):
        ret = model.apply({'params': params['model']},
                          batch,
                          warp_alpha=state.warp_alpha,
                          rngs={
                              'fine': fine_key,
                              'coarse': coarse_key
                          })

        losses = {}
        stats = {}
        if 'fine' in ret:
            losses['fine'], stats['fine'] = _compute_loss_and_stats(
                params, ret['fine'])
        if 'coarse' in ret:
            losses['coarse'], stats['coarse'] = _compute_loss_and_stats(
                params, ret['coarse'], use_elastic_loss=use_elastic_loss)

        if use_background_loss:
            background_loss = compute_background_loss(
                model,
                state=state,
                params=params['model'],
                key=reg_key,
                points=batch['background_points'],
                noise_std=scalar_params.background_noise_std)
            background_loss = background_loss.mean()
            losses['background'] = (
                scalar_params.background_loss_weight * background_loss)
            stats['background_loss'] = background_loss

        # Total objective is the plain sum of every enabled loss term.
        return sum(losses.values()), stats

    optimizer = state.optimizer
    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
    (_, stats), grad = grad_fn(optimizer.target)
    # Average gradients and stats across the mapped axis named 'batch'.
    grad = jax.lax.pmean(grad, axis_name='batch')
    stats = jax.lax.pmean(stats, axis_name='batch')
    new_optimizer = optimizer.apply_gradient(
        grad, learning_rate=scalar_params.learning_rate)
    new_state = state.replace(optimizer=new_optimizer)
    return new_state, stats, rng_key
def value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False):
    r"""Creates a function which evaluates both ``fun`` and the grad of ``fun``.

    NOTE: You only need this in a very specific case that you want to take a
    gradient **inside** a :func:`transform`\ ed function and the function you
    are differentiating uses :func:`set_state`. For example:

    >>> class MyModule(hk.Module):
    ...   def __call__(self, x):
    ...     hk.set_state("last", jnp.sum(x))
    ...     return x ** 2

    >>> def f(x):
    ...   m = MyModule()
    ...   y, g = hk.value_and_grad(m)(x)
    ...   return y, g

    >>> f = hk.transform_with_state(f)
    >>> x = jnp.array(2.)
    >>> _ = jax.jit(f.init)(None, x)

    Args:
      fun: Function to be differentiated. Its arguments at positions specified
        by ``argnums`` should be arrays, scalars, or standard Python
        containers. It should return a scalar (which includes arrays with
        shape ``()`` but not arrays with shape ``(1,)`` etc.)
      argnums: Optional, integer or tuple of integers. Specifies which
        positional argument(s) to differentiate with respect to (default 0).
      has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where
        the first element is considered the output of the mathematical
        function to be differentiated and the second element is auxiliary
        data. Default False.
      holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
        holomorphic. Default False.

    Returns:
      A function with the same arguments as ``fun`` that evaluates both
      ``fun`` and the gradient of ``fun`` and returns them as a pair (a
      two-element tuple). If ``argnums`` is an integer then the gradient has
      the same shape and type as the positional argument indicated by that
      integer. If argnums is a tuple of integers, the gradient is a tuple of
      values with the same shapes and types as the corresponding arguments.

    Raises:
      ValueError: If called outside of a transformed function.
    """
    if not base.inside_transform():
        # The message names this function (it previously pointed at grad()).
        raise ValueError(
            "hk.value_and_grad() should not be used outside of "
            "hk.transform(). Use jax.value_and_grad() instead.")

    @functools.wraps(fun)
    def stateful_fun(*args, **kwargs):
        # State is threaded in explicitly and the change is returned as aux
        # output, so the differentiated function has no hidden side effects.
        state_in = kwargs.pop("hk_state")
        with temporary_internal_state(state_in):
            out = fun(*args, **kwargs)
            out, aux = (out if has_aux else (out, None))
            state_out = difference(state_in, internal_state())
            return out, (aux, state_out)

    grad_fun = jax.value_and_grad(stateful_fun, argnums=argnums,
                                  has_aux=True, holomorphic=holomorphic)

    @functools.wraps(grad_fun)
    def wrapper(*args, **kwargs):
        kwargs["hk_state"] = internal_state()
        (value, (aux, hk_state)), grads = grad_fun(*args, **kwargs)
        # Re-apply the state changes made while evaluating `fun`.
        update_internal_state(hk_state)
        if has_aux:
            return (value, aux), grads
        else:
            return value, grads

    return wrapper
def fista_step(data, loss_and_prox_op, model_param, options):
    """FISTA optimization step for solving a regularized problem.

    Args:
      data: A tuple of inputs and labels passed to the loss function.
      loss_and_prox_op: Tuple of (loss_f, prox_g). loss_f is the loss function
        that takes in model_param, inputs, and labels. prox_g is the proximity
        operator for g; it is called on the flattened parameter vector and the
        step size.
      model_param: Current model parameters to be passed to loss_f.
      options: A dictionary of optimizer specific hyper-parameters. Keys read
        here: 'step_size', 'acceleration', 't', 'verbose', 'reuse_last_step'
        and 'y'. The dictionary is copied, not mutated.

    Returns:
      A tuple ``(new_model_param, new_options)``. ``new_options`` is a copy of
      ``options`` carrying the momentum state ('y', 't') when acceleration is
      on, and the accepted 'step_size' when 'reuse_last_step' is set.
    """
    options = dict(options)
    step_size = options.get('step_size', 1.0)
    acceleration = options.get('acceleration', True)
    t = options.get('t', 1.0)
    verbose = options.get('verbose', False)
    reuse_last_step = options.get('reuse_last_step', False)

    loss_f, prox_g = loss_and_prox_op
    inputs, labels = data[0], data[1]
    fun_f = lambda param: loss_f(param, inputs, labels)
    value_and_grad_f = jax.value_and_grad(fun_f)
    x, unravel_fn = ravel_pytree(model_param)
    y = options.get('y', x)
    value_f, grad_f = value_and_grad_f(unravel_fn(y))
    grad_f, unravel_fn = ravel_pytree(grad_f)

    def next_candidate(step_size):
        return prox_g(y - grad_f * step_size, step_size)

    def stop_cond(step_size, next_iter):
        diff = next_iter - y
        sqdist = jnp.sum(diff**2)
        # We do not compute the non-smooth term (g in the paper)
        # as it cancels out from value_F and value_Q.
        # The candidate is a flat vector; restore the parameter structure so
        # the loss is evaluated the same way as for value_f above.
        value_bigf = fun_f(unravel_fn(next_iter))
        value_bigq = (value_f + jnp.sum(diff * grad_f)
                      + 0.5 / step_size * sqdist)
        return value_bigf <= value_bigq

    x_old = x
    step_size, x = backtracking(next_candidate, stop_cond, step_size, options)

    # Acceleration: store the extrapolated point and momentum scalar for the
    # next call.
    if acceleration:
        t_next = (1 + jnp.sqrt(1 + 4 * t**2)) / 2.
        options['y'] = x + (t - 1) / t_next * (x - x_old)
        options['t'] = t_next

    if reuse_last_step:
        options['step_size'] = step_size
    if verbose:
        logging.info('Step size: %f', step_size)

    return unravel_fn(x), options
def __init__(self, f, argnums=0):
    """Store ``f`` together with its derivative functions.

    Args:
        f: Function to differentiate.
        argnums: Positional argument index (or indices) to differentiate
            with respect to. Used for both ``vg`` and ``hess``.
    """
    self.f = f
    self.argnums = argnums
    self.vg = jax.value_and_grad(self.f, argnums)
    # Differentiate w.r.t. the same argument(s) as `vg`; previously the
    # Hessian was always taken w.r.t. the default argument.
    # NOTE(review): assumes `hessian` is `jax.hessian` (accepts `argnums`) — confirm.
    self.hess = hessian(self.f, argnums=argnums)
def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator,
                              inverse_mass_matrix, position, rng_key,
                              init_step_size):
    """
    Finds a reasonable step size by tuning `init_step_size`. This function is
    used to avoid working with a too large or too small step size in HMC.

    The step size is repeatedly doubled or halved, one trial trajectory step
    per iteration, until the search direction flips or the step size reaches
    the limits of its floating point type.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman

    :param potential_fn: A callable to compute potential energy.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param momentum_generator: A generator to get a random momentum variable.
    :param inverse_mass_matrix: Inverse of mass matrix.
    :param position: Current position of the particle.
    :param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
    :param float init_step_size: Initial step size to be tuned.
    :return: a reasonable value for step size.
    :rtype: float
    """
    # We are going to find a step_size which make accept_prob (Metropolis correction)
    # near the target_accept_prob. If accept_prob:=exp(-delta_energy) is small,
    # then we have to decrease step_size; otherwise, increase step_size.
    # Note: despite its name this holds log(0.8); it is compared against
    # -delta_energy (the log of accept_prob) below.
    target_accept_prob = np.log(0.8)

    _, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    z = position
    potential_energy, z_grad = value_and_grad(potential_fn)(z)
    finfo = np.finfo(get_dtype(init_step_size))

    def _body_fn(state):
        # state = (step_size, last_direction, direction, rng_key)
        step_size, _, direction, rng_key = state
        rng_key, rng_key_momentum = random.split(rng_key)
        # scale step_size: increase 2x or decrease 2x depends on direction;
        # direction=1 means keep increasing step_size, otherwise decreasing step_size.
        # Note that the direction is -1 if delta_energy is `NaN`, which may be the
        # case for a diverging trajectory (e.g. in the case of evaluating log prob
        # of a value simulated using a large step size for a constrained sample site).
        step_size = (2.0**direction) * step_size
        r = momentum_generator(inverse_mass_matrix, rng_key_momentum)
        # One trial update, always started from the original position `z`.
        _, r_new, potential_energy_new, _ = vv_update(
            step_size, inverse_mass_matrix, (z, r, potential_energy, z_grad))
        energy_current = kinetic_fn(inverse_mass_matrix, r) + potential_energy
        energy_new = kinetic_fn(inverse_mass_matrix, r_new) + potential_energy_new
        delta_energy = energy_new - energy_current
        direction_new = np.where(target_accept_prob < -delta_energy, 1, -1)
        # The direction just used becomes `last_direction` of the next state.
        return step_size, direction, direction_new, rng_key

    def _cond_fn(state):
        step_size, last_direction, direction, _ = state
        # condition to run only if step_size is not too small or we are not decreasing step_size
        not_small_step_size_cond = (step_size > finfo.tiny) | (direction >= 0)
        # condition to run only if step_size is not too large or we are not increasing step_size
        not_large_step_size_cond = (step_size < finfo.max) | (direction <= 0)
        not_extreme_cond = not_small_step_size_cond & not_large_step_size_cond
        # Keep going on the first iteration (last_direction == 0) and while the
        # direction has not flipped.
        return not_extreme_cond & ((last_direction == 0) | (direction == last_direction))

    step_size, _, _, _ = while_loop(_cond_fn, _body_fn,
                                    (init_step_size, 0, 0, rng_key))
    return step_size
def train_step(model, rng, state, batch, lr):
    """One optimization step.

    Args:
      model: The linen model.
      rng: jnp.ndarray, random number generator.
      state: utils.TrainState, state of the model/optimizer.
      batch: dict, a mini-batch of data for training. Keys read here are
        "rays" and "pixels".
      lr: float, real-time learning rate.

    Returns:
      new_state: utils.TrainState, new training state.
      stats: utils.Stats with fields loss, psnr, loss_c, psnr_c and weight_l2,
        averaged over the "batch" axis.
      rng: jnp.ndarray, updated random number generator.

    Raises:
      ValueError: If the model returns neither one nor two sets of outputs.
    """
    rng, key_0, key_1 = random.split(rng, 3)

    def loss_fn(variables):
        rays = batch["rays"]
        ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
        if len(ret) not in (1, 2):
            # Note the trailing space: the two literals are concatenated.
            raise ValueError(
                "ret should contain either 1 set of output (coarse only), "
                "or 2 sets of output (coarse as ret[0] and fine as ret[1]).")
        # The main prediction is always at the end of the ret list.
        rgb, unused_disp, unused_acc = ret[-1]
        loss = ((rgb - batch["pixels"][Ellipsis, :3])**2).mean()
        psnr = utils.compute_psnr(loss)
        if len(ret) > 1:
            # If there are both coarse and fine predictions, we compute the
            # loss for the coarse prediction (ret[0]) as well.
            rgb_c, unused_disp_c, unused_acc_c = ret[0]
            loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3])**2).mean()
            psnr_c = utils.compute_psnr(loss_c)
        else:
            loss_c = 0.
            psnr_c = 0.

        def tree_sum_fn(fn):
            return jax.tree_util.tree_reduce(lambda x, y: x + fn(y),
                                             variables,
                                             initializer=0)

        # Mean squared weight: sum of squares divided by the parameter count.
        weight_l2 = (tree_sum_fn(lambda z: jnp.sum(z**2)) /
                     tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))

        stats = utils.Stats(loss=loss,
                            psnr=psnr,
                            loss_c=loss_c,
                            psnr_c=psnr_c,
                            weight_l2=weight_l2)
        return loss + loss_c + FLAGS.weight_decay_mult * weight_l2, stats

    (_, stats), grad = (jax.value_and_grad(loss_fn, has_aux=True)(
        state.optimizer.target))
    grad = jax.lax.pmean(grad, axis_name="batch")
    stats = jax.lax.pmean(stats, axis_name="batch")

    # Clip the gradient by value.
    if FLAGS.grad_max_val > 0:
        clip_fn = lambda z: jnp.clip(z, -FLAGS.grad_max_val,
                                     FLAGS.grad_max_val)
        grad = jax.tree_util.tree_map(clip_fn, grad)

    # Clip the (possibly value-clipped) gradient by norm.
    if FLAGS.grad_max_norm > 0:
        grad_norm = jnp.sqrt(
            jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(y**2),
                                      grad,
                                      initializer=0))
        # The small constant guards against division by a zero norm.
        mult = jnp.minimum(1, FLAGS.grad_max_norm / (1e-7 + grad_norm))
        grad = jax.tree_util.tree_map(lambda z: mult * z, grad)

    new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
    new_state = state.replace(optimizer=new_optimizer)
    return new_state, stats, rng
def train_step(i, opt_state, batch):
    """Apply one optimizer update and return ``(new_opt_state, loss)``.

    The loss and its gradient are taken w.r.t. the parameters extracted from
    ``opt_state``; ``i`` is the step index forwarded to the optimizer update.
    """
    current_params = get_params(opt_state)
    loss, grad = jax.value_and_grad(loss_fn)(current_params, batch)
    new_opt_state = opt_update(i, grad, opt_state)
    return new_opt_state, loss
def fmin_bfgs(func, x0, args=(), options=None):
    """
    The BFGS algorithm from
        Algorithm 6.1 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 136-143

    Notes:
        We utilise boolean arithmetic to avoid jax.cond calls which don't work on accelerators.
        A side effect is that we perform more gradient evaluations than scipy's BFGS

    Args:
        func: callable
            Function of the form f(x) where x is a flat ndarray and returns a real scalar.
            The function should be composed of operations with vjp defined.
            NOTE(review): the original text mentioned a `_nojit` flag for
            non-jittable functions; no such parameter exists in this signature.
        x0: ndarray
            initial variable
        args: tuple, optional
            Extra arguments to pass to func as func(x,*args)
        options: Optional dict of parameters
            maxiter: int
                Maximum number of iterations (bounds `state.k`).
                Defaults to ``jnp.size(x0) * 200``.
            norm: float
                Order of norm for convergence check. Default inf.
            gtol: float
                Terminates minimization when |grad|_norm < g_tol. Default 1e-5.
            ls_maxiter: int
                Maximum number of linesearch iterations. Default 10.
            hess_inv: ndarray
                Initial inverse-Hessian estimate. Defaults to the identity.

    Returns: BFGSResults
        `status` is 0 when converged, 1 when `maxiter` was reached,
        2 + `ls_status` when the line search failed, and -1 otherwise.
    """

    if options is None:
        options = dict()
    maxiter: Optional[int] = options.get('maxiter', None)
    norm: float = options.get('norm', jnp.inf)
    gtol: float = options.get('gtol', 1e-5)
    ls_maxiter: int = options.get('ls_maxiter', 10)

    state = BFGSResults(converged=False,
                        failed=False,
                        k=0,
                        nfev=0,
                        ngev=0,
                        nhev=0,
                        x_k=x0,
                        f_k=None,
                        g_k=None,
                        H_k=None,
                        status=None,
                        ls_status=jnp.array(0))

    if maxiter is None:
        maxiter = jnp.size(x0) * 200

    d = x0.shape[0]

    initial_H = jnp.eye(d)
    initial_H = options.get('hess_inv', initial_H)

    def func_with_args(x):
        return func(x, *args)

    value_and_grad = jax.value_and_grad(func_with_args)

    f_0, g_0 = value_and_grad(x0)
    state = state._replace(f_k=f_0,
                           g_k=g_0,
                           H_k=initial_H,
                           nfev=state.nfev + 1,
                           ngev=state.ngev + 1,
                           converged=jnp.linalg.norm(g_0, ord=norm) < gtol)

    def body(state):
        # Search direction from the current inverse-Hessian estimate.
        p_k = -(state.H_k @ state.g_k)
        line_search_results = line_search(value_and_grad,
                                          state.x_k,
                                          p_k,
                                          old_fval=state.f_k,
                                          gfk=state.g_k,
                                          maxiter=ls_maxiter)
        state = state._replace(nfev=state.nfev + line_search_results.nfev,
                               ngev=state.ngev + line_search_results.ngev,
                               failed=line_search_results.failed,
                               ls_status=line_search_results.status)
        s_k = line_search_results.a_k * p_k
        x_kp1 = state.x_k + s_k
        f_kp1 = line_search_results.f_k
        g_kp1 = line_search_results.g_k
        # print(g_kp1)
        y_k = g_kp1 - state.g_k
        rho_k = jnp.reciprocal(y_k @ s_k)

        sy_k = s_k[:, None] * y_k[None, :]
        w = jnp.eye(d) - rho_k * sy_k
        # Keep the previous estimate when rho_k is not finite (y_k @ s_k == 0).
        H_kp1 = jnp.where(jnp.isfinite(rho_k),
                          jnp.linalg.multi_dot([w, state.H_k, w.T])
                          + rho_k * s_k[:, None] * s_k[None, :],
                          state.H_k)

        converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

        state = state._replace(converged=converged,
                               k=state.k + 1,
                               x_k=x_kp1,
                               f_k=f_kp1,
                               g_k=g_kp1,
                               H_k=H_kp1
                               )

        return state

    # Iterate until converged, the line search fails, or maxiter is hit.
    state = while_loop(
        lambda state: (~ state.converged) & (~state.failed) & (state.k < maxiter),
        body,
        state)

    state = state._replace(status=jnp.where(state.converged, jnp.array(0),  # converged
                                            jnp.where(state.k == maxiter, jnp.array(1),  # max iters reached
                                                      jnp.where(state.failed, jnp.array(2)+state.ls_status,  # ls failed (+ reason)
                                                                jnp.array(-1)))))  # undefined

    return state
from .. import get_backend, default_backend from ..tensor.common import _TensorViewer from .autodiff import AutoDiffOptimizerMixin import jax def _final_objective(pars, data, fixed_vals, model, objective, fixed_idx, variable_idx): tensorlib, _ = get_backend() tv = _TensorViewer([fixed_idx, variable_idx]) pars = tensorlib.astensor(pars) constrained_pars = tv.stitch([fixed_vals, pars]) return objective(constrained_pars, data, model)[0] _jitted_objective_and_grad = jax.jit(jax.value_and_grad(_final_objective), static_argnums=(3, 4, 5, 6)) class jax_optimizer(AutoDiffOptimizerMixin): """JAX Optimizer Backend.""" def setup_minimize(self, objective, data, pdf, init_pars, par_bounds, fixed_vals=None): """ Prepare Minimization for AutoDiff-Optimizer.
def value_and_grad_f(*args, **kwargs):
    """Evaluate ``fun`` and its gradient, dispatching on real/complex types.

    The case is chosen from the dtype of the differentiated arguments and of
    the output of ``fun``:

    * C -> C: differentiated as a holomorphic function.
    * C -> R: not supported.
    * R -> C: real and imaginary parts are differentiated separately and
      recombined as ``re + 1j * im``.
    * R -> R: plain ``jax.value_and_grad``.

    Returns:
        ``(value, grad)``, or ``((value, aux), grad)`` when ``has_aux`` is
        set — the same structure in every supported case.

    Raises:
        RuntimeError: For a complex-argument, real-output function.
    """
    out_shape = eval_shape(fun, *args, has_aux=has_aux, **kwargs)

    _args_iterable = (argnums,) if isinstance(argnums, int) else argnums

    # only check if derivable arguments are complex
    if tree_leaf_iscomplex([args[i] for i in _args_iterable]):
        if is_complex(out_shape):  # C -> C
            return jax.value_and_grad(
                fun,
                argnums=argnums,
                has_aux=has_aux,
                allow_int=allow_int,
                holomorphic=True,
            )(*args, **kwargs)
        # C -> R
        raise RuntimeError("C->R function detected, but not supported.")

    if not is_complex(out_shape):  # R -> R
        return jax.value_and_grad(
            fun, argnums=argnums, has_aux=has_aux, allow_int=allow_int
        )(*args, **kwargs)

    # R -> C
    if has_aux:

        def real_fun(*args, **kwargs):
            val, aux = fun(*args, **kwargs)
            return val.real, aux

        def imag_fun(*args, **kwargs):
            val, aux = fun(*args, **kwargs)
            return val.imag, aux

        # With has_aux=True, jax.value_and_grad returns ((value, aux), grad).
        (out_r, aux), grad_r = jax.value_and_grad(
            real_fun, argnums=argnums, has_aux=True, allow_int=allow_int
        )(*args, **kwargs)
        (out_j, _), grad_j = jax.value_and_grad(
            imag_fun, argnums=argnums, has_aux=True, allow_int=allow_int
        )(*args, **kwargs)
    else:
        real_fun = lambda *args, **kwargs: fun(*args, **kwargs).real
        imag_fun = lambda *args, **kwargs: fun(*args, **kwargs).imag

        out_r, grad_r = jax.value_and_grad(
            real_fun, argnums=argnums, has_aux=False, allow_int=allow_int
        )(*args, **kwargs)
        out_j, grad_j = jax.value_and_grad(
            imag_fun, argnums=argnums, has_aux=False, allow_int=allow_int
        )(*args, **kwargs)

    out = out_r + 1j * out_j
    grad = jax.tree_util.tree_map(lambda re, im: re + 1j * im, grad_r, grad_j)

    if has_aux:
        return (out, aux), grad
    return out, grad
def test_freeze_feature_encoder(self):
    # Checks that freezing the feature encoder leaves losses and outputs
    # untouched, zeroes the feature-extractor gradients, and keeps every
    # other gradient identical.
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    input_values = inputs_dict["input_values"]
    attention_mask = inputs_dict["attention_mask"]

    model = FlaxWav2Vec2ForPreTraining(config)
    params = model.params

    # Dummy objective: summed cosine similarity between the projected and
    # the projected-quantized states.
    def compute_loss(params, input_values, attention_mask,
                     freeze_feature_encoder: bool = False,
                     epsilon: float = 1e-8):
        outputs = model(
            input_values,
            attention_mask=attention_mask,
            freeze_feature_encoder=freeze_feature_encoder,
            params=params,
        )
        similarity = optax.cosine_similarity(
            outputs.projected_states,
            outputs.projected_quantized_states,
            epsilon=epsilon)
        return similarity.sum(), outputs.to_tuple()

    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

    def run(freeze):
        return grad_fn(params, input_values, attention_mask,
                       freeze_feature_encoder=freeze)

    # Unfrozen first, then frozen.
    (loss, outputs), grads = run(False)
    (loss_frozen, outputs_frozen), grads_frozen = run(True)

    # Outputs and losses must remain precisely equal.
    for output, output_frozen in zip(outputs, outputs_frozen):
        self.assertTrue((output == output_frozen).all())
    self.assertEqual(loss, loss_frozen)

    grads = flatten_dict(grads)
    grads_frozen = flatten_dict(grads_frozen)

    # Both gradient dicts must expose the same keys.
    self.assertEqual(grads.keys(), grads_frozen.keys())

    def split_by_feature_extractor(flat_grads):
        # Partition values by whether their key mentions the feature
        # extractor, preserving dict order within each group.
        inside, outside = [], []
        for key, value in flat_grads.items():
            if "feature_extractor" in key:
                inside.append(value)
            else:
                outside.append(value)
        return tuple(inside), tuple(outside)

    fe_grads, other_grads = split_by_feature_extractor(grads)
    fe_grads_frozen, other_grads_frozen = split_by_feature_extractor(
        grads_frozen)

    # Feature-extractor gradients: exactly zero when frozen, and containing
    # positive entries when unfrozen.
    # NOTE(review): the intent reads "non-zero", but the check is `> 0.0` — confirm.
    for fe_grad, fe_grad_frozen in zip(fe_grads, fe_grads_frozen):
        self.assertTrue((fe_grad_frozen == 0.0).all())
        self.assertTrue((fe_grad > 0.0).any())

    # Every layer outside the frozen 'feature_extractor' keeps its gradient.
    for grad, grad_frozen in zip(other_grads, other_grads_frozen):
        self.assertTrue((grad == grad_frozen).all())
def find_connected_curve(w1, w2, loss_fn, n_bends=3, train_steps=100,
                         lr=0.05, scheduler_name='constant', batch_size=4,
                         log_every=1, seed=43, use_jit=False):
    """Find a curve connecting two local minima.

    The bends of a curve with fixed endpoints ``w1`` and ``w2`` are trained
    with Adam so that ``loss_fn`` averaged over uniformly sampled curve
    positions is minimized.

    Args:
        w1: jnp.array, the first local minimum
        w2: jnp.array, the second local minimum
        loss_fn: callable, evaluated on each sampled point of the curve
        n_bends: int, number of bends between endpoints
        train_steps: int, number of training steps
        lr: float, initial learning rate
        scheduler_name: str, scheduler name
        batch_size: int, number of curve positions sampled per step
        log_every: int, log (and save the 'last' checkpoint) every this many
            steps
        seed: int, random seed
        use_jit: bool, whether to use jit-compilation of gradient function.

    Returns:
        None in the code shown here (there is no return statement); progress
        is recorded through ``save_checkpoints`` and ``expmgr.log``.
    """
    print('Find a curve connecting two local minimas')
    curve = get_beizer_curve(n_bends + 2)  # includes the endpoints

    def curve_loss(t_, params_):
        # t_ should be (m,) vector.
        params_ = jnp.vstack([w1, params_, w2])
        c = curve(t_, params_)
        loss_sum = 0
        for c_ in c:
            loss_sum += loss_fn(c_)
        return loss_sum / float(len(t_))

    # differentiate the 2nd argument
    print('Get the gradient of the loss function')
    grad_loss_fn = jax.value_and_grad(curve_loss, argnums=1)
    if use_jit:
        print('Use jit compilation')
        grad_loss_fn = jax.jit(grad_loss_fn)

    start_step = 0
    history = {'loss': [], 'grad': [], 'params': []}
    # Kept as a Python float throughout so it can always be logged, even when
    # no finite loss has been observed yet.
    min_loss = float('inf')

    scheduler = qnnops.get_scheduler(lr, train_steps, name=scheduler_name)
    init_fun, update_fun, get_params = qnnops.get_optimizer(
        'adam', None, scheduler)

    # Pick evenly divided points from the line segment between w1 and w2.
    # NOTE(review): with alpha increasing, the first bend lies nearest w2
    # although the points are stacked as [w1, bends, w2] — confirm intended.
    alpha = jnp.linspace(0, 1, n_bends + 2)
    init_bends = jnp.vstack(
        [alpha[i] * w1 + (1 - alpha[i]) * w2 for i in range(1, n_bends + 1)])

    print('State initialization')
    optimizer_state = init_fun(init_bends)
    params = get_params(optimizer_state)
    rng = jax.random.PRNGKey(seed)

    for step in range(start_step, train_steps):
        rng, key = jax.random.split(rng)
        # Loss = E[ L(\phi_\theta(t))] where t ~ Unif(0, 1)
        t = jax.random.uniform(key, (batch_size, ))
        loss, grad = grad_loss_fn(t, params)
        optimizer_state = update_fun(step, grad, optimizer_state)
        params = get_params(optimizer_state)

        loss_value = loss.item()
        history['loss'].append(loss_value)
        history['grad'].append(onp.array(grad))
        history['params'].append(onp.array(params))
        if loss_value < min_loss:
            min_loss = loss_value
            save_checkpoints('best', step, w1, w2, params, history,
                             optimizer_state)

        if step % log_every == 0:
            grad_norm = jnp.linalg.norm(grad).item()
            logging_output = OrderedDict(loss=loss_value,
                                         lr=scheduler(step),
                                         grad_norm=grad_norm)
            logging_output['min_loss'] = min_loss
            expmgr.log(step, logging_output)
            save_checkpoints('last', step, w1, w2, params, history,
                             optimizer_state)
        # Drop device buffers from this step before the next one.
        del loss, grad
        gc.collect()
def optimise(epoch, m, s, params, xs, ys):
    """Run one optimiser update and return ``(m, s, params, loss)``.

    The loss and its gradient w.r.t. ``params`` are evaluated on ``(xs, ys)``;
    the gradient is then handed to ``jb.adabeliefʹ`` together with the
    optimiser state ``m`` and ``s``. The returned loss is the value computed
    before the update.
    """
    loss_and_grad = jax.value_and_grad(loss)
    loss_value, grads = loss_and_grad(params, xs, ys)
    m, s, params = jb.adabeliefʹ(epoch, grads, m, s, params)
    return m, s, params, loss_value
def linear_model_loss_regularized_dw_lp(w, x, y, reg_coeff, norm_type):
    """Linear-model loss penalized by the norm of its gradient w.r.t. ``w``.

    Returns ``loss + reg_coeff * norm_f(dloss/dw, norm_type)``.
    """
    loss, dloss_dw = jax.value_and_grad(linear_model_loss, argnums=0)(w, x, y)
    penalty = norm_f(dloss_dw, norm_type)
    return loss + reg_coeff * penalty
def fun(flat_nb_params):
    """Return ``flat_loss`` and its gradient as a Python float and an array."""
    loss_and_grad = value_and_grad(flat_loss)
    value, gradient = loss_and_grad(flat_nb_params)
    return float(value), np.array(gradient)
return cql_term def total_critic_loss(q_params, policy_params, target_q_params, alpha, cql_alpha, transitions, key): critic_loss_term = critic_loss(q_params, policy_params, target_q_params, alpha, transitions, key) cql_term = cql_loss(q_params, policy_params, transitions, key) total = critic_loss_term + cql_alpha * cql_term return total, { 'critic_loss': critic_loss_term, 'cql_loss': cql_term } critic_grad = jax.value_and_grad(total_critic_loss, has_aux=True) def critic_update_step(q_params, target_q_params, optim_state, policy_params, alpha, cql_alpha, transitions, key): total_critic_loss_and_aux, critic_grads = critic_grad( q_params, policy_params, target_q_params, alpha, cql_alpha, transitions, key) critic_grads = jax.lax.pmean(critic_grads, 'devices') # Apply critic gradients critic_update, optim_state = q_optimizer.update( critic_grads, optim_state) q_params = optax.apply_updates(q_params, critic_update) target_q_params = jax.tree_multimap( lambda x, y: x * (1 - tau) + y * tau, target_q_params,
def unroll(x_init, theta, t_current, T, K):
    """Run ``K`` steps of ``update`` with ``jax.lax.scan`` and collect the loss.

    Returns:
        ``(L, (x_current, x_trajectory, t_current))``: the loss carried in the
        final scan state, followed by auxiliary outputs (final ``x``, the
        stacked and transposed per-step scan outputs, and the final time).
    """
    L_sum = 0.0
    # Scan carry; `update` is expected to consume and return this 6-tuple.
    initial_state = (L_sum, x_init, theta, t_current, T, K)
    iterations = jax.lax.iota(
        jnp.int32, K
    )  # This is like jnp.arange(0, K), I don't think iota works for different starting points than 0...
    state, outputs = jax.lax.scan(update, initial_state, iterations)
    (L, x_current, theta, t_current, T, K) = state
    x_trajectory = jnp.stack(outputs).T
    return L, (x_current, x_trajectory, t_current)


# Gradient of the unrolled loss w.r.t. `theta` (argument 1); `T` and `K`
# (arguments 3 and 4) are treated as static by jit.
grad_unroll = jax.jit(jax.grad(unroll, argnums=1, has_aux=True),
                      static_argnums=(3, 4))
# Same as above, but also returns the loss value and auxiliary outputs.
loss_and_grad_unroll = jax.jit(jax.value_and_grad(unroll,
                                                  argnums=1,
                                                  has_aux=True),
                               static_argnums=(3, 4))


# Functions for RTRL/UORO
# ----------------------------------------------------------------------------------
@jax.jit
def single_step(theta, x, t_current, T):
    """Take one gradient step on ``x`` with a time-dependent learning rate.

    The learning rate interpolates linearly in ``t_current / T`` between
    ``exp(theta[0])`` (at ``t_current == 0``) and ``exp(theta[1])`` (at
    ``t_current == T``).
    """
    g = loss_grad(x)
    lr = jnp.exp(theta[0]) * (T - t_current) / T + jnp.exp(
        theta[1]) * t_current / T
    x_new = x - lr * g
    return x_new
def optimize_params(c, incidence_scenarios, parameterizer, loss_fn=None,
                    optimization_params=None, verbose=True):
    """Modifies c in place subject to parameterizer to minimize a loss function.

    Picks params and runs parameterizer.xr_apply_params(c, params) to minimize
    loss_fn on the control arm event curves given incidence_scenarios. The
    final optimized params are stored in c['final_params'].

    Args:
      c: xr.Dataset specifying the trial with all data_vars required to call
        sim.recruitment and sim.control_arm_events.
      incidence_scenarios: xr.DataArray of shape [scenario, location, time],
        the forecast incidence.
      parameterizer: a Parameterizer specifying what the trial planner can
        control.
      loss_fn: optional function which takes trajectories of control arm
        events (an jnp.array of shape [scenario, time]) and returns a
        jnp.array of losses of shape [scenario]. Defaults to
        negative_mean_successiness.
      optimization_params: optional dict of stuff related to how to do the
        optimization. Keys read here (with defaults): 'min_steps' (40),
        'max_steps' (200), 'epsilon' (1e-3), 'smoothing_window' (20),
        'learning_rate' (0.1) and 'optimizer' (optim.Adam(learning_rate)).
      verbose: if True, print some stuff.
    """
    # Run non-differentiable simulation once, because they do more error checking.
    participants = sim.recruitment(c)
    sim.control_arm_events(c, participants, incidence_scenarios)

    c_ = sim.JaxDataset(c)
    incidence_scenarios_ = jnp.asarray(
        incidence_scenarios.transpose('scenario', 'location', 'time'))
    _, recruitment_fn, _ = sim.RECRUITMENT_REGISTRY[
        sim.get_recruitment_type(c)]
    historical_events = c.historical_control_arm_events
    if 'location' in historical_events.dims:
        historical_events = c.historical_control_arm_events.sum('location')
    historical_events_ = jnp.array(historical_events.values)

    if loss_fn is None:
        loss_fn = negative_mean_successiness(c)

    def loss(params):
        parameterizer.apply_params(c_, params)
        participants_ = recruitment_fn(c_)
        control_arm_events_ = sim.differentiable_control_arm_events(
            c_, participants_, incidence_scenarios_)
        # Prepend the historical events to every scenario's trajectory.
        historical_ = jnp.broadcast_to(
            historical_events_,
            control_arm_events_.shape[:1] + historical_events_.shape)
        control_arm_events_ = jnp.concatenate(
            [historical_, control_arm_events_], axis=1)
        return loss_fn(control_arm_events_).mean()

    if optimization_params is None:
        optimization_params = dict()
    min_steps = optimization_params.get('min_steps', 40)
    max_steps = optimization_params.get('max_steps', 200)
    epsilon = optimization_params.get('epsilon', 1e-3)
    smoothing_window = optimization_params.get('smoothing_window', 20)
    learning_rate = optimization_params.get('learning_rate', 0.1)
    optimizer = optimization_params.get('optimizer',
                                        optim.Adam(learning_rate))
    initial_params = parameterizer.init_params()
    optimizer = optimizer.create(jnp.array(initial_params.values))

    # Loop-invariant: build the differentiated function once.
    loss_and_grad = jax.value_and_grad(loss)

    loss_curve = []
    while True:
        loss_value, grad = loss_and_grad(optimizer.target)
        loss_curve.append(float(loss_value))
        optimizer = optimizer.apply_gradient(grad)
        step = len(loss_curve)
        # Only look back once the curve is long enough; otherwise the negative
        # index would fall off the start of the list when min_steps is small.
        if (step > min_steps and step >= smoothing_window and
                (loss_curve[-smoothing_window] - loss_curve[-1]) < epsilon):
            # Not much progress recently. Call it a day.
            break
        if step >= max_steps:
            print('Hit max step limit. You can control this exit condition '
                  'by setting max_steps in optimization_params.')
            break
        if verbose and (step % 10 == 0):
            print(f'step {step}, loss value {loss_curve[-1]}')

    final_params = np.array(optimizer.target)
    c['final_params'] = xr.DataArray(final_params,
                                     coords=initial_params.coords)
    parameterizer.xr_apply_params(c, c.final_params)
def test_implicit_differentiation_versus_autodiff(self, lse_mode):
    """Compare implicit-differentiation gradients with unrolled autodiff.

    For a cost-matrix geometry and a point-cloud geometry, the regularised
    OT cost is differentiated w.r.t. the weights `a` and the points `x`,
    once with `implicit_differentiation=True` and once with `False`. The
    implicit gradients are checked against central finite differences and
    against the autodiff gradients.
    """
    epsilon = 0.05
    fd_step = 1e-3
    tolerances = dict(rtol=1e-02, atol=1e-02)

    def reg_ot_cost(geom, a, tau_a, tau_b, implicit):
        # Single place where the solver is invoked for both geometries.
        out = sinkhorn.sinkhorn(
            geom,
            a=a,
            b=self.b,
            tau_a=tau_a,
            tau_b=tau_b,
            threshold=1e-4,
            lse_mode=lse_mode,
            implicit_differentiation=implicit)
        return out.reg_ot_cost

    def loss_g(a, x, implicit=True):
        # Squared Euclidean cost matrix assembled by hand.
        x_sq = jnp.sum(x**2, axis=1)[:, jnp.newaxis]
        y_sq = jnp.sum(self.y**2, axis=1)[jnp.newaxis, :]
        cost = x_sq + y_sq - 2 * jnp.dot(x, self.y.T)
        geom = geometry.Geometry(cost_matrix=cost, epsilon=epsilon)
        return reg_ot_cost(geom, a, 0.8, 0.87, implicit)

    def loss_pcg(a, x, implicit=True):
        geom = pointcloud.PointCloud(x, self.y, epsilon=epsilon)
        return reg_ot_cost(geom, a, 1.0, 0.95, implicit)

    def jitted_value_and_grad(loss, implicit):
        return jax.jit(
            jax.value_and_grad(lambda a, x: loss(a, x, implicit),
                               argnums=(0, 1)))

    for loss in [loss_g, loss_pcg]:
        value_imp, (grad_a_imp, grad_x_imp) = jitted_value_and_grad(
            loss, True)(self.a, self.x)
        value_auto, (grad_a_auto, grad_x_auto) = jitted_value_and_grad(
            loss, False)(self.a, self.x)
        self.assertAllClose(value_imp, value_auto)

        # Gradient w.r.t. a: finite differences along a centred direction,
        # then implicit vs autodiff.
        direction = jax.random.uniform(self.rngs[4], (self.n, )) / 10
        direction = direction - jnp.mean(direction)  # center perturbation
        cost_plus = loss(self.a + fd_step * direction, self.x)
        cost_minus = loss(self.a - fd_step * direction, self.x)
        directional = jnp.sum(direction * grad_a_imp)
        self.assertAllClose(directional,
                            (cost_plus - cost_minus) / (2 * fd_step),
                            **tolerances)
        # The means are subtracted before comparing: gradients w.r.t. a are
        # only determined up to an additive constant here (the primal
        # variable is in the simplex).
        self.assertAllClose(grad_a_imp - jnp.mean(grad_a_imp),
                            grad_a_auto - jnp.mean(grad_a_auto),
                            **tolerances)

        # Gradient w.r.t. x: same two checks, without centring.
        direction = jax.random.uniform(self.rngs[4], (self.n, self.dim))
        cost_plus = loss(self.a, self.x + fd_step * direction)
        cost_minus = loss(self.a, self.x - fd_step * direction)
        directional = jnp.sum(direction * grad_x_imp)
        self.assertAllClose(directional,
                            (cost_plus - cost_minus) / (2 * fd_step),
                            **tolerances)
        self.assertAllClose(grad_x_imp, grad_x_auto, **tolerances)
# NOTE(review): the next three statements are the tail of a loss function
# whose `def` line lies outside this view; they are reproduced unchanged.
    # Absolute duration error, zeroed wherever `mask` is zero.
    masked_loss = jnp.abs(durations - x.durations) * mask
    # Normalise by the mask total (an average over kept entries, assuming
    # `mask` is 0/1 -- TODO confirm).
    loss = jnp.sum(masked_loss) / jnp.sum(mask)
    return loss, aux


# Jitted forward pass of the duration model built with is_training=False.
forward_fn = jax.jit(hk.transform_with_state(lambda x: DurationModel(is_training=False)(x)).apply)


def predict_duration(params, aux, rng, x: DurationInput):
    """Run the inference-mode model and return (prediction, x.durations).

    The second value returned by `forward_fn` is discarded.
    """
    d, _ = forward_fn(params, aux, rng, x)
    return d, x.durations


# Jitted loss with is_training fixed to False.
val_loss_fn = jax.jit(partial(loss_fn, is_training=False))

# Value-and-gradient of the loss; has_aux=True because loss_fn returns
# (loss, aux).
loss_vag = jax.value_and_grad(loss_fn, has_aux=True)

# Gradient clipping by global norm followed by AdamW, both configured
# from FLAGS.
optimizer = optax.chain(
    optax.clip_by_global_norm(FLAGS.max_grad_norm),
    optax.adamw(FLAGS.duration_learning_rate, weight_decay=FLAGS.weight_decay)
)


@jax.jit
def update(params, aux, rng, optim_state, inputs: DurationInput):
    """Perform one optimisation step.

    Returns:
      (loss, (new_params, new_aux, new_rng, new_optim_state)).
    """
    # One half of the split key is used for this step; the other half is
    # handed back to the caller.
    rng, new_rng = jax.random.split(rng)
    (loss, new_aux), grads = loss_vag(params, aux, rng, inputs)
    updates, new_optim_state = optimizer.update(grads, optim_state, params)
    new_params = optax.apply_updates(params, updates)
    return loss, (new_params, new_aux, new_rng, new_optim_state)
def _find_sigmoid_upper_bound_tangent(range_lb: Tensor, range_ub: Tensor, tol: float) -> Tensor:
    """Search the point where the concave hull of the sigmoid stops being linear.

    The concave upper bound of the sigmoid can be several things:
      - It can be the sigmoid itself (if the interval considered is in R+).
      - It can be linear (if the upper bound is small enough; this is a bit
        more general than just the interval being in R-).
      - It can start linear and become the sigmoid at some tangent point.

    This function searches for that tangent point. For the other cases, the
    caller is expected to have narrowed the search range such that
    range_lb == range_ub, in which case the loop exits early.

    The search combines a binary search with the Newton method.

    Args:
      range_lb: Lower bound of the domain on which to define the convex hull.
      range_ub: Upper bound of the domain on which to define the convex hull.
      tol: Tolerance criterion for convergence.

    Returns:
      final_t: Tangent point at which the concave upper bound of the sigmoid
        should go from linear to sigmoid, with the same shape as range_lb.
        If range_lb == range_ub, that number should be returned.
    """
    lb_flat = jnp.reshape(range_lb, (-1, ))
    ub_flat = jnp.reshape(range_ub, (-1, ))

    sigmoid = jax.nn.sigmoid

    def sigmoid_slope(x):
        return sigmoid(x) * (1 - sigmoid(x))

    def tangent_gap(x, lb):
        # The point we are looking for satisfies
        #   sigmoid_slope(x) = (sigmoid(x) - sigmoid(lb)) / (x - lb).
        return sigmoid_slope(x) - (sigmoid(x) - sigmoid(lb)) / jnp.maximum(
            x - lb, tol)

    batched_gap_and_slope = jax.vmap(jax.value_and_grad(tangent_gap))

    bracket_lo = jnp.maximum(lb_flat, 0.)

    # When lb is very negative an approximate solution shrinks the search
    # space, which makes the binary search converge significantly faster.
    # If lb < -1e3 then sigmoid(lb) = 0, so we can just solve
    #       sigmoid_slope(x) = sigmoid(x) / (x - lb)
    #   <=> sigmoid(x) (1 - sigmoid(x)) = sigmoid(x) / (x - lb)
    #   <=> 1 - sigmoid(x) = 1 / (x - lb)
    #   <=> exp(-x) / (1 + exp(-x)) = 1 / (x - lb)
    #   <=> exp(-x) * (x - lb - 1) = 1
    #   <=> exp(x) - x = -lb - 1
    # For large values exp(x) - x ~ exp(x), so the optimal t is close to
    # log(-lb - 1). A padding of +1 is added so that no valid solution is
    # excluded from the search space.
    cap_for_large_l = jnp.where(
        lb_flat < -1e3,
        jnp.log(jnp.maximum(-lb_flat - 1, 1.)) + 1,
        float('inf'))
    bracket_hi = jnp.minimum(ub_flat, cap_for_large_l)

    midpoint = 0.5 * (bracket_lo + bracket_hi)
    step0 = jnp.array(0)

    def refine(state):
        step, t, lo, hi = state
        gap, dgap = batched_gap_and_slope(t, lb_flat)
        # Tighten the bracket according to the sign of the gap at t.
        next_lo = jnp.where(gap >= 0., jnp.maximum(lo, t), lo)
        next_hi = jnp.where(gap <= 0., jnp.minimum(hi, t), hi)
        # Newton proposal; fall back to bisection when the derivative is
        # tiny or the proposal leaves the bracket.
        proposal = t - gap / dgap
        escaped = (proposal <= next_lo) | (proposal >= next_hi)
        use_bisection = (jnp.abs(dgap) <= tol) | escaped
        next_t = jnp.where(use_bisection, 0.5 * (next_lo + next_hi), proposal)
        return step + 1, next_t, next_lo, next_hi

    def keep_going(state):
        step, t, lo, hi = state
        # A point has not converged only if both
        #  - the gap between average slope and sigmoid derivative is large,
        #  - the bracket is still wide.
        # Failing either criterion means the point has converged.
        gap_is_large = jnp.abs(tangent_gap(t, lb_flat)) >= tol
        bracket_is_wide = jnp.abs(hi - lo) >= tol
        pending = gap_is_large & bracket_is_wide
        # Keep searching while the iteration budget allows it and at least
        # one point has not converged.
        return jnp.logical_and(step <= 100, jnp.any(pending))

    final_state = jax.lax.while_loop(
        keep_going, refine, (step0, midpoint, bracket_lo, bracket_hi))
    return jnp.reshape(final_state[1], range_lb.shape)
def gradient_descent_line_search_step(data, loss_f, model_param, options):
    """Gradient Descent optimization with line search step.

    Args:
      data: A tuple of inputs and labels passed to the loss function.
      loss_f: The loss function that takes in model_param, inputs, and labels.
      model_param: Current model parameters to be passed to loss_f.
      options: A dictionary of optimizer specific hyper-parameters. All keys
        are optional: 'beta', 'beta_prime', 'step_size', 'verbose',
        'reuse_last_step' and 'bound_step'. The dictionary is copied, so the
        caller's object is never mutated; the copy is also forwarded to
        `backtracking`.

    Returns:
      A tuple `(updated model parameters, options copy)`. When
      'reuse_last_step' is set, the returned options carry the accepted step
      size under 'step_size'.
    """
    options = dict(options)
    beta = options.get('beta', 0.9)
    beta_prime = options.get('beta_prime', 1e-4)
    step_size = options.get('step_size', 10000.0)
    verbose = options.get('verbose', False)
    reuse_last_step = options.get('reuse_last_step', False)
    # Read like every other option so a missing key means "off" rather than
    # raising KeyError.
    bound_step = options.get('bound_step', False)

    inputs, labels = data[0], data[1]
    loss_with_data_f = lambda param: loss_f(param, inputs, labels)
    value_and_grad_f = jax.value_and_grad(loss_with_data_f)
    value, grad = value_and_grad_f(model_param)

    # Maximum learning rate allowed from Theorem 5 in Gunasekar et al. 2017
    if bound_step:
        # Bound by dual of L2
        b_const = jnp.max(jnp.linalg.norm(inputs, ord=2, axis=0))
        step_size = min(step_size, 1 / (b_const * b_const * value))

    grad, _ = ravel_pytree(grad)
    x, unravel_fn = ravel_pytree(model_param)
    # If we normalize step_size will be harder to tune.
    direction = -grad
    # Directional derivative at the current point; it does not depend on the
    # candidate step, so it is computed once.
    gd = jnp.sum(grad * direction)

    # TODO(fartash): consider using the condition in FISTA
    def next_candidate(step_size):
        next_iter = x + step_size * direction
        next_value, next_grad = value_and_grad_f(unravel_fn(next_iter))
        next_grad, _ = ravel_pytree(next_grad)
        return next_iter, next_value, next_grad

    def stop_cond(step_size, res):
        _, next_value, next_grad = res
        # Sufficient-decrease condition.
        cond1 = next_value <= value + beta_prime * step_size * gd
        # NOTE(review): with direction = -grad, gd is non-positive while the
        # left-hand side is non-negative, so for beta >= 0 this check appears
        # to always pass. Kept as-is because tightening it would change which
        # steps are accepted.
        cond2 = jnp.sum(jnp.abs(next_grad * direction)) >= beta * gd
        return cond1 and cond2

    step_size, res = backtracking(next_candidate,
                                  stop_cond,
                                  step_size,
                                  options=options)
    next_param = res[0]
    if reuse_last_step:
        options['step_size'] = step_size
    if verbose:
        logging.info('Step size: %f', step_size)

    return unravel_fn(next_param), options
def update(params, batch, opt_state, step=0):
    """Apply one gradient step and return the refreshed parameters.

    Args:
      params: Current parameters, differentiated through `loss`.
      batch: Data passed to `loss` alongside `params`.
      opt_state: Optimizer state consumed by `opt_update`.
      step: Iteration index forwarded to `opt_update`. Defaults to 0, which
        reproduces the previous hard-coded behaviour; pass the real step
        count when the optimizer uses a step-size schedule.

    Returns:
      A tuple `(new_params, new_opt_state, loss_value)`.
    """
    value, grads = value_and_grad(loss)(params, batch)
    opt_state = opt_update(step, grads, opt_state)
    return get_params(opt_state), opt_state, value
import copt.penalty

# Regression problem produced by `datasets.make_regression()`; its
# dimensions are read from the design matrix.
X, y = datasets.make_regression()
n_samples, n_features = X.shape


def loss(w):
    """Squared error loss.

    Returns the sum of squared residuals of `np.dot(X, w)` against `y`,
    divided by the number of samples.
    """
    z = np.dot(X, w) - y
    return np.sum(z * z) / n_samples


# jax.value_and_grad returns both the objective value and its gradient,
# which is the format that COPT accepts (hence jac=True in the solver
# call below).
f_grad = jax.value_and_grad(loss)

# Start the solver from the zero vector.
w0 = onp.zeros(n_features)

l1_ball = copt.penalty.L1Norm(0.1)
# The callback is given the full objective (loss plus penalty);
# presumably it records it per iteration in `trace_fx` -- TODO confirm.
cb = cp.utils.Trace(lambda x: loss(x) + l1_ball(x))
sol = cp.minimize_proximal_gradient(f_grad, w0, prox=l1_ball.prox, callback=cb, jac=True)

# Plot the traced objective on a logarithmic y axis.
plt.plot(cb.trace_fx, lw=3)
plt.yscale("log")
plt.xlabel("# Iterations")
plt.ylabel("Objective value")
plt.grid()
def __post_init__(self):
    """Build and cache the value-and-gradient transform of `self.fun`.

    `self.has_aux` is forwarded so that a `fun` returning `(value, aux)`
    is differentiated with respect to `value` only.
    """
    # Constructed once here so it can be reused instead of rebuilt.
    value_and_grad_fun = jax.value_and_grad(self.fun, has_aux=self.has_aux)
    self._value_and_grad_fun = value_and_grad_fun
def train(target_network, optimizer, states, actions, next_states, rewards,
          terminals, num_tau_samples, num_tau_prime_samples,
          num_quantile_samples, cumulative_gamma, double_dqn, kappa, rng,
          mico_weight, distance_fn, tau, alpha, clip_value_min):
    """Run a training step.

    Combines a quantile Huber loss with a metric (representation distance)
    loss, weighted by `mico_weight`, and applies one gradient update.

    Returns:
      (rng, optimizer, loss, quantile_loss, metric_loss).
    """

    def loss_fn(model, rng_input, target_quantile_vals, target_r,
                target_next_r):
        def apply_model(m, x, y, z):
            return m(x=x, num_quantiles=y, rng=z)

        # Only the states are mapped over; the other arguments are shared.
        batched_model = jax.vmap(apply_model, in_axes=(None, 0, None, None))
        net_out = batched_model(model, states, num_tau_samples, rng_input)
        quantile_values = net_out.quantile_values
        quantiles = net_out.quantiles
        representations = jnp.squeeze(net_out.representation)

        def select_action(x, y):
            return x[:, y][:, None]

        chosen_action_quantile_values = jax.vmap(select_action)(
            quantile_values, actions)

        # Shape of bellman_errors and huber_loss:
        # batch_size x num_tau_prime_samples x num_tau_samples x 1.
        bellman_errors = (target_quantile_vals[:, :, None, :] -
                          chosen_action_quantile_values[:, None, :, :])
        abs_errors = jnp.abs(bellman_errors)

        # The huber loss (see Section 2.3 of the paper) has two cases:
        #   quadratic when |bellman_errors| <= kappa,
        #   linear    when |bellman_errors| >  kappa.
        quadratic_part = (
            (abs_errors <= kappa).astype(jnp.float32) * 0.5 *
            bellman_errors ** 2)
        linear_part = (
            (abs_errors > kappa).astype(jnp.float32) * kappa *
            (abs_errors - 0.5 * kappa))
        huber_loss = quadratic_part + linear_part

        # Tile by num_tau_prime_samples along a new dimension. Shape is now
        # batch_size x num_tau_prime_samples x num_tau_samples x 1.
        # These quantiles are used for the quantile huber loss below
        # (see section 2.3 of the paper).
        tiled = jnp.tile(quantiles[:, None, :, :],
                         [1, num_tau_prime_samples, 1, 1])
        quantiles = tiled.astype(jnp.float32)

        # Shape: batch_size x num_tau_prime_samples x num_tau_samples x 1.
        negative_error = jax.lax.stop_gradient(
            (bellman_errors < 0).astype(jnp.float32))
        quantile_huber_loss = (jnp.abs(quantiles - negative_error) *
                               huber_loss) / kappa

        # Sum over the current quantile value (num_tau_samples) dimension,
        # then average over the target quantile value
        # (num_tau_prime_samples) dimension.
        quantile_huber_loss = jnp.sum(quantile_huber_loss, axis=2)
        quantile_huber_loss = jnp.mean(quantile_huber_loss, axis=1)

        online_dist = metric_utils.representation_distances(
            representations, target_r, distance_fn)
        target_dist = metric_utils.target_distances(
            target_next_r, rewards, distance_fn, cumulative_gamma)
        metric_loss = jnp.mean(
            jax.vmap(losses.huber_loss)(online_dist, target_dist))

        loss = ((1. - mico_weight) * quantile_huber_loss +
                mico_weight * metric_loss)
        return jnp.mean(loss), (jnp.mean(quantile_huber_loss), metric_loss)

    # A non-None tau selects the Munchausen targets.
    if tau is not None:
        rng, target_quantile_vals, target_r, target_next_r = (
            munchausen_target_quantile_values(
                target_network, states, actions, next_states, rewards,
                terminals, num_tau_prime_samples, num_quantile_samples,
                cumulative_gamma, rng, tau, alpha, clip_value_min))
    else:
        rng, target_quantile_vals, target_r, target_next_r = (
            target_quantile_values(
                optimizer.target, target_network, states, next_states,
                rewards, terminals, num_tau_prime_samples,
                num_quantile_samples, cumulative_gamma, double_dqn, rng))

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    rng, rng_input = jax.random.split(rng)
    (loss, (quantile_loss, metric_loss)), grad = grad_fn(
        optimizer.target, rng_input, target_quantile_vals, target_r,
        target_next_r)
    optimizer = optimizer.apply_gradient(grad)
    return rng, optimizer, loss, quantile_loss, metric_loss
def main():
    """Build an embedding model and train it on word/context pairs.

    Loads the vocabulary from 'reuters_vocab.pkl', runs a smoke-test forward
    pass, generates (word, context) id pairs from nltk's reuters corpus and
    performs plain gradient descent over a 1% slice of those pairs.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only use this
    # with a trusted vocabulary file.
    with open('reuters_vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)

    v_dim = len(vocab['num_to_word'])
    e_dim = 1024
    prng_key = random.PRNGKey(0xdeadbeef)

    # One-hot probe vector at a random vocabulary index.
    hot_ix = random.uniform(prng_key, (1, )) * v_dim
    hot_ix = int(jnp.floor(hot_ix)[0])
    probe = jops.index_update(jnp.zeros(v_dim, dtype=jnp.float32), hot_ix,
                              1.0)

    # first create the architecture of the model
    mdl = Embedding(v_dim, e_dim)
    # then complete the model spec by giving an example input tensor
    params = mdl.init(prng_key, probe)
    # now apply the params to the model with the input
    out = mdl.apply(params, probe)
    print(f'out: {out}')
    print(f'shape: {out.shape}')

    # let's train the model on nltk's reuters dataset
    from nltk.corpus import reuters
    train_texts = [reuters.words(fname) for fname in reuters.fileids()]

    # now generate word-context elements: for every offset the backward
    # neighbour is emitted before the forward one.
    window_size = 2
    word_pairs = []
    for text in tqdm(train_texts, desc='make train set'):
        text_len = len(text)
        for position, word in enumerate(text):
            for offset in range(1, window_size + 1):
                for neighbour in (position - offset, position + offset):
                    if 0 <= neighbour < text_len:
                        word_pairs.append((word, text[neighbour]))

    # convert words to vocab IDs, dropping pairs with unknown words
    w2n = vocab['word_to_num']
    id_pairs = [(w2n[word], w2n[context])
                for word, context in tqdm(word_pairs, desc='gen word pairs')
                if word in w2n and context in w2n]
    id_pairs = jnp.array(id_pairs)
    print(f'train pairs: {len(id_pairs)}')

    # run grad desc
    id_pairs = id_pairs[0:len(id_pairs) // 100]
    lr = 0.3
    batch_size = 2500

    # TEST: what if I run one at a time? (disabled experiment, kept for
    # reference)
    #   loss_fn = lambda x, y : nll_loss_fn(mdl, params, x, y)
    #   grad_fn = jax.value_and_grad(loss_fn)
    #   grad_calc_fn = lambda params, x, y : grad_fn(params, x, y)
    #   param_update_fn = lambda old, grad: old - lr * grad
    #   template_vec = jnp.zeros(v_dim, dtype=jnp.float32)
    #   for epoch in trange(5):
    #     for pair in tqdm(id_pairs):
    #       x = jops.index_update(template_vec, pair[0], 1.)
    #       y = jops.index_update(template_vec, pair[1], 1.)
    #       loss_val, grad = grad_calc_fn(params, x, y)
    #       params = jax.tree_multimap(param_update_fn, paramd, grad)
    #       import pdb; pdb.set_trace()
    #       pass
    # TEST END

    def sgd_step(old, grad):
        return old - lr * grad

    split_points = jnp.arange(batch_size, len(id_pairs), batch_size)
    batches = jnp.split(id_pairs, split_points)
    for epoch in trange(1):
        # TODO: shuffle & batch id_pairs
        pbar = trange(len(batches), desc=f'epoch:--- - loss:------')
        for batch in batches:
            x_vals, y_vals = __id_to_one_hot(batch, v_dim)
            loss_fn = nll_loss_fn(mdl, params, x_vals, y_vals)
            loss_val, grad = jax.value_and_grad(loss_fn)(params)
            params = jax.tree_multimap(sgd_step, params, grad)
            pbar.set_description(
                f'epoch:{epoch:03d} - loss:{loss_val:0.4f}')
            pbar.update()

    # Drop into the debugger so the trained params can be inspected.
    import pdb
    pdb.set_trace()
    print('done!')