def _check_args_and_kwargs(self, *args, **kwargs):
    for k in kwargs.keys():
        if k not in GraphV2.allowed_kwargs:
            raise ValueError(f"unknown argument {k}")
    consumed_args = 0
    if kwargs.get("inputs", False):
        inputs = jax.tree_flatten(kwargs["inputs"])[0]
    else:
        if len(args) == 0:
            raise ValueError("Input is not passed correctly")
        else:
            inputs = jax.tree_flatten(args[consumed_args])[0]
            consumed_args += 1
    if kwargs.get("outputs", False):
        outputs = jax.tree_flatten(kwargs["outputs"])[0]
    else:
        if len(args) == 0:
            raise ValueError("Output is not passed correctly")
        else:
            outputs = jax.tree_flatten(args[consumed_args])[0]
    self._output_names = [o.name for o in outputs]
    return inputs, outputs

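# _check_args_and_kwargs above only keeps the leaves of the input/output
# pytrees. As a reminder of the jax.tree_flatten / jax.tree_unflatten
# round-trip contract it relies on, here is a minimal standalone sketch
# (hypothetical data only):
def _tree_flatten_roundtrip_example():
    import jax

    tree = {"inputs": [1, 2], "outputs": (3,)}
    leaves, treedef = jax.tree_flatten(tree)
    assert leaves == [1, 2, 3]  # leaves come out in a deterministic order
    assert jax.tree_unflatten(treedef, leaves) == tree  # treedef restores structure
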
def vectors_to_blocks(
    self,
    parameter_structured_vector: Any,
) -> Sequence[BlockVector]:
    """Splits the parameters into the values for the corresponding blocks."""
    in_vars = jax.tree_unflatten(self._in_tree, self._jaxpr.invars)
    params_vars = in_vars[self.params_index]
    params_vars_flat = jax.tree_flatten(params_vars)[0]
    params_values_flat = jax.tree_flatten(parameter_structured_vector)[0]
    assert len(params_vars_flat) == len(params_values_flat)
    params_dict = dict(zip(params_vars_flat, params_values_flat))
    per_block_vectors = []
    for eqn in self._layer_tags:
        if eqn.primitive.name == "generic_tag":
            block_vars = eqn.invars
        else:
            block_vars = eqn.primitive.split_all_inputs(eqn.invars)[2]
        per_block_vectors.append(tuple(params_dict.pop(v) for v in block_vars))
    if params_dict:
        raise ValueError(
            f"From the parameters the following structure is not "
            f"assigned to any block: {params_dict}. Most likely "
            f"this part of the parameters is not part of the graph "
            f"reaching the losses.")
    return tuple(per_block_vectors)

def assertStructureAllClose(self, x, y, **kwargs):
    x_v, x_tree = jax.tree_flatten(x)
    y_v, y_tree = jax.tree_flatten(y)
    self.assertEqual(x_tree, y_tree)
    for xi, yi in zip(x_v, y_v):
        self.assertEqual(xi.shape, yi.shape)
        self.assertAllClose(xi, yi, check_dtypes=True, **kwargs)

def tree_allclose(*trees):
    """Determines whether all elements of `trees`

    a) have the same tree structure, and
    b) have corresponding leaves that all fulfill ``np.allclose(leaf1, leaf2)``,

    so that the trees are equal up to the numerical tolerances of `np.allclose`.
    """
    if len(trees) > 2:
        return tree_allclose(trees[0], trees[1]) and tree_allclose(*trees[1:])
    if len(trees) < 2:
        return True
    tree1, tree2 = trees
    _, tree_def1 = tree_flatten(tree1)
    _, tree_def2 = tree_flatten(tree2)
    if tree_def1 != tree_def2:
        return False
    return np.all(
        tree_flatten(
            tree_multimap(lambda arr1, arr2: np.allclose(arr1, arr2),
                          tree1, tree2))[0])

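# A minimal usage sketch for tree_allclose above; the example data is
# hypothetical and only assumes numpy plus the helpers the function itself
# imports (tree_flatten, tree_multimap):
def _tree_allclose_example():
    import numpy as np

    tree_a = {"w": np.ones((2, 2)), "b": [np.zeros(3)]}
    tree_b = {"w": np.ones((2, 2)), "b": [np.zeros(3) + 1e-9]}
    tree_c = {"w": np.ones((2, 2))}  # different structure: "b" is missing

    assert tree_allclose(tree_a, tree_b)      # same structure, leaves all close
    assert not tree_allclose(tree_a, tree_c)  # structure mismatch -> False
    assert tree_allclose(tree_a)              # fewer than two trees -> True
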
def update(self, model_gradient, old_optimizer, new_optimizer):
    """Computes a number of statistics from the model params and update.

    Statistics computed:
      Per layer update variances and norms.
      Per layer gradient variance and norms.
      Per layer param norms.
      Ratio of parameter norm to update and update variance.

    Args:
      model_gradient: A pytree of the same shape as the model_params pytree
        that was used when the metrics_grabber object was created.
      old_optimizer: The optimizer before the param update.
      new_optimizer: The optimizer after the param update.

    Returns:
      An updated class object.
    """
    grads_flat, treedef = jax.tree_flatten(model_gradient)
    new_params_flat, _ = jax.tree_flatten(new_optimizer.target.params)
    old_params_flat, _ = jax.tree_flatten(old_optimizer.target.params)

    # flatten_up_to here is needed to avoid flattening the _MetricsLeafState
    # nodes.
    state_flat = treedef.flatten_up_to(self.state)
    new_states_flat = [
        _update_param_stats(state, grad, new_param - old_param, new_param,
                            self.config)
        for state, grad, old_param, new_param in zip(
            state_flat, grads_flat, old_params_flat, new_params_flat)
    ]
    return self.replace(state=jax.tree_unflatten(treedef, new_states_flat))

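# The flatten_up_to call above deserves a minimal sketch: flattening `params`
# yields a treedef, and treedef.flatten_up_to(other_tree) flattens `other_tree`
# only down to that structure, so custom leaf objects (here _MetricsLeafState)
# are returned whole rather than being flattened into their fields. The data
# below is a hypothetical illustration:
def _flatten_up_to_example():
    import jax

    params = {"dense": {"kernel": 1.0}}
    stats = {"dense": {"kernel": {"norm": 3.0, "var": 0.1}}}

    _, treedef = jax.tree_flatten(params)
    stats_flat = treedef.flatten_up_to(stats)
    # Full flattening would have produced [3.0, 0.1]; flatten_up_to stops at
    # the structure of `params`, keeping the inner dict whole.
    assert stats_flat == [{"norm": 3.0, "var": 0.1}]
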
def update_fn(opt, lr, images, labels, rng):
    """Update step."""
    measurements = {}

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch'))

    def loss_fn(params, images, labels):
        logits, _ = model.apply({'params': flax.core.freeze(params)},
                                images,
                                train=True,
                                rngs={'dropout': rng_model_local})
        accuracy = jnp.mean(
            jnp.equal(jnp.argmax(logits, axis=-1), jnp.argmax(labels, axis=-1)))
        return getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
            logits=logits, labels=labels), accuracy

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (l, train_accuracy), g = grad_fn(opt.target, images, labels)
    l, g = jax.lax.pmean((l, g), axis_name='batch')
    measurements['accuracy'] = train_accuracy

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
        grads, _ = jax.tree_flatten(g)
        l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
        measurements['l2_grads'] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if config.get('grad_clip_norm'):
        g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
        g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)

    decay_rules = config.get('weight_decay', []) or []
    if isinstance(decay_rules, numbers.Number):
        decay_rules = [('.*kernel.*', decay_rules)]
    sched_m = lr / config.lr.base if config.get('weight_decay_decouple') else lr

    def decay_fn(v, wd):
        return (1.0 - sched_m * wd) * v

    opt = opt.replace(target=train_utils.tree_map_with_regex(
        decay_fn, opt.target, decay_rules))

    params, _ = jax.tree_flatten(opt.target)
    measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    return opt, l, rng, measurements

def _check_output(self, actual, expected):
    expected_leaves, expected_structure = jax.tree_flatten(expected)
    actual_leaves, actual_structure = jax.tree_flatten(actual)
    assert all(
        isinstance(x, jnp.ndarray) for x in expected_leaves), "bad example_data"
    if actual_structure != expected_structure:
        raise TypeError(
            f"func has bad return tree_structure, expected: {expected_structure}, "
            f"got: {actual_structure}")
    if not all(isinstance(x, jnp.ndarray) for x in actual_leaves):
        bad_types = tuple(
            type(x) for x in actual_leaves if not isinstance(x, jnp.ndarray))
        raise TypeError(
            "all leaves of dist_params must be of type: jax.numpy.ndarray, "
            f"found leaves of type: {bad_types}")
    if not all(a.shape == b.shape
               for a, b in zip(actual_leaves, expected_leaves)):
        shapes_tree = jax.tree_multimap(
            lambda a, b: f"{a.shape} {'!=' if a.shape != b.shape else '=='} {b.shape}",
            actual, expected)
        raise TypeError(f"found leaves with unexpected shapes: {shapes_tree}")

def test_jit_pytree_return(self):

    @iree.jax.jit
    def apply_sqrt(pytree):
        return jax.tree_map(jnp.sqrt, pytree)

    np.random.seed(0)
    input_tree = {
        "a": [
            normal((2, 3)),
            {"b": normal(3)},
        ],
        "c": (
            {"d": [normal(2), normal(3)]},
            (normal(1), normal(4)),
        )
    }

    expected = jax.tree_map(jnp.sqrt, input_tree)
    expected_arrays, expected_tree = jax.tree_flatten(expected)
    result = apply_sqrt(input_tree)
    result_arrays, result_tree = jax.tree_flatten(result)
    self.assertEqual(expected_tree, result_tree)
    for expected_array, result_array in zip(expected_arrays, result_arrays):
        np.testing.assert_allclose(expected_array, result_array, **TOLERANCE)

def __local_cost_and_grad_function(local_cost_fun, dtype, logpsi, pars, *args):
    costfun_outdtype = _outdtype[local_cost_fun]
    lcfun_u = _unjitted_fun[local_cost_fun]

    if dtype is complex:
        der_local_cost_fun = jax.value_and_grad(lcfun_u, argnums=1,
                                                holomorphic=True)
        return der_local_cost_fun(logpsi, pars, *args)
    else:
        if costfun_outdtype is complex:
            _costfun_re = lambda w: lcfun_u(logpsi, w, *args).real
            _costfun_im = lambda w: lcfun_u(logpsi, w, *args).imag

            # Pullbacks
            cost_val_re, cost_vjp_re = jax.vjp(_costfun_re, pars)
            cost_val_im, cost_vjp_im = jax.vjp(_costfun_im, pars)

            cost_val = cost_val_re + 1.0j * cost_val_im
            primal = jax.numpy.ones(cost_val.shape)

            # Apply pullbacks to primals
            cost_grad_re, tree_fun = jax.tree_flatten(cost_vjp_re(primal)[0])
            cost_grad_im, _ = jax.tree_flatten(cost_vjp_im(primal)[0])

            out_flat = [re + 1.0j * im
                        for re, im in zip(cost_grad_re, cost_grad_im)]
            grad_c = jax.tree_unflatten(tree_fun, out_flat)

            return (cost_val, grad_c)
        else:
            der_local_cost_fun = jax.value_and_grad(lcfun_u, argnums=1)
            return der_local_cost_fun(logpsi, pars, *args)

def _local_value_and_grad_notcentered_kernel(
    logpsi, pars, vp, mel, v, real_to_complex=False
):
    # We can branch on this argument with a plain `if` under jit, because it
    # is exposed statically to the jit.
    if real_to_complex:
        logpsi_vp_r, f_vjp_r = jax.vjp(lambda w: (logpsi(w, vp).real), pars)
        logpsi_vp_j, f_vjp_j = jax.vjp(lambda w: (logpsi(w, vp).imag), pars)
        logpsi_vp = logpsi_vp_r + 1.0j * logpsi_vp_j

        vec = mel * jax.numpy.exp(logpsi_vp - logpsi(pars, v))
        vec_r = vec.real
        vec_j = vec.imag
        loc_val = vec.sum()

        vr_grad_r, tree_fun = jax.tree_flatten(f_vjp_r(vec_r)[0])
        vj_grad_r, _ = jax.tree_flatten(f_vjp_r(vec_j)[0])
        vr_grad_j, _ = jax.tree_flatten(f_vjp_j(vec_r)[0])
        vj_grad_j, _ = jax.tree_flatten(f_vjp_j(vec_j)[0])

        r_flat = [rr + 1j * jr for rr, jr in zip(vr_grad_r, vj_grad_r)]
        j_flat = [rr + 1j * jr for rr, jr in zip(vr_grad_j, vj_grad_j)]
        out_flat = [re + 1.0j * im for re, im in zip(r_flat, j_flat)]
        grad_c = jax.tree_unflatten(tree_fun, out_flat)
    else:
        logpsi_vp, f_vjp = jax.vjp(lambda w: logpsi(w, vp), pars)
        vec = mel * jax.numpy.exp(logpsi_vp - logpsi(pars, v))
        loc_val = vec.sum()
        grad_c = f_vjp(vec)[0]

    return loc_val, grad_c

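# Both kernels above assemble the gradient of a complex-valued output with
# respect to real parameters from two real VJPs (one for the real part, one
# for the imaginary part). A minimal standalone sketch of that identity, using
# a hypothetical function f:
def _complex_grad_from_real_vjps_example():
    import jax
    import jax.numpy as jnp

    def f(w):
        # Complex-valued scalar output of real parameters.
        return jnp.sum(w ** 2) + 1.0j * jnp.sum(w ** 3)

    w = jnp.array([1.0, 2.0])
    y_re, vjp_re = jax.vjp(lambda w: f(w).real, w)
    y_im, vjp_im = jax.vjp(lambda w: f(w).imag, w)
    grad = vjp_re(jnp.ones_like(y_re))[0] + 1.0j * vjp_im(jnp.ones_like(y_im))[0]
    # d(Re f)/dw = 2w and d(Im f)/dw = 3w**2
    assert jnp.allclose(grad, 2.0 * w + 1.0j * 3.0 * w ** 2)
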
def update_fn(opt, lr, images, labels, rng):
    """Update step."""
    measurements = {}

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch'))
    rng_model_local, diag_noise_rng, standard_noise_rng = jax.random.split(
        rng_model_local, num=3)

    def loss_fn(params, images, labels):
        logits, _ = model.apply(
            {'params': flax.core.freeze(params)},
            images,
            train=True,
            rngs={
                'dropout': rng_model_local,
                'diag_noise_samples': diag_noise_rng,
                'standard_norm_noise_samples': standard_noise_rng
            })
        label_indices = config.get('label_indices')
        if label_indices:
            logits = logits[:, label_indices]
        return getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
            logits=logits, labels=labels)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    l, g = train_utils.accumulate_gradient(jax.value_and_grad(loss_fn),
                                           opt.target, images, labels,
                                           config.get('grad_accum_steps'))
    l, g = jax.lax.pmean((l, g), axis_name='batch')

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
        grads, _ = jax.tree_flatten(g)
        l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
        measurements['l2_grads'] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if config.get('grad_clip_norm'):
        g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
        g = jax.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)
    opt = opt.replace(target=weight_decay_fn(opt.target, lr))

    params, _ = jax.tree_flatten(opt.target)
    measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    return opt, l, rng, measurements

def update_fn(opt, lr, images, labels, rng):
    """Update step. Copy to deterministic_utils.py whenever changes are made!"""
    measurements = {}

    # Split rng and return next_rng for the following step.
    rng, next_rng = jax.random.split(rng, 2)
    rng_local = jax.random.fold_in(rng, jax.lax.axis_index('batch'))

    def loss_fn(params, images, labels):
        logits, _ = model.apply({'params': flax.core.freeze(params)},
                                images,
                                train=True,
                                rngs={'dropout': rng_local})
        label_indices = config.get('label_indices')
        if label_indices:
            logits = logits[:, label_indices]
        loss = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
            logits=logits, labels=labels)
        return loss, logits

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    (l, logits), g = train_utils.accumulate_gradient(
        jax.value_and_grad(loss_fn, has_aux=True), opt.target, images, labels,
        config.get('grad_accum_steps'))
    l, g = jax.lax.pmean((l, g), axis_name='batch')
    measurements['training_loss'] = l

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
        grads, _ = jax.tree_flatten(g)
        l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
        measurements['l2_grads'] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if config.get('grad_clip_norm'):
        g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
        g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)
    opt = opt.replace(target=weight_decay_fn(opt.target, lr))

    params, _ = jax.tree_flatten(opt.target)
    measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    top1_idx = jnp.argmax(logits, axis=1)
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    prec1 = jax.lax.psum(jnp.sum(top1_correct), axis_name='batch') / batch_size
    measurements['training_prec@1'] = prec1
    measurements['learning_rate'] = lr
    return opt, next_rng, measurements

def update_fn(opt, lr, images, labels, rng):
    """Update step."""
    measurements = {}

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch'))

    def loss_fn(params, images, labels):
        logits, _ = model.apply({'params': flax.core.freeze(params)},
                                images,
                                train=True,
                                rngs={'dropout': rng_model_local})
        return getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
            logits=logits, labels=labels)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    l, g = train_utils.accumulate_gradient(jax.value_and_grad(loss_fn),
                                           opt.target, images, labels,
                                           config.get('grad_accum_steps'))
    l, g = jax.lax.pmean((l, g), axis_name='batch')

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get('grad_accum_steps', 1) == 1 or grad_clip_norm is not None:
        grads, _ = jax.tree_flatten(g)
        l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
        measurements['l2_grads'] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if grad_clip_norm is not None:
        g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g)
        g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)

    decay_rules = weight_decay or []
    if isinstance(decay_rules, numbers.Number):
        decay_rules = [('.*kernel.*', decay_rules)]
    sched_m = lr / config.lr.base if config.get('weight_decay_decouple') else lr

    def decay_fn(v, wd):
        return (1.0 - sched_m * wd) * v

    opt = opt.replace(target=train_utils.tree_map_with_regex(
        decay_fn, opt.target, decay_rules))

    params, _ = jax.tree_flatten(opt.target)
    measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    return opt, l, rng, measurements

def tree_vmap(f, lst):
    stacked = jax.tree_map(lambda args: jnp.stack(args), lst)
    out_stacked = jax.vmap(f)(stacked)
    _, outer_treedef = jax.tree_flatten([None] * len(lst))
    _, inner_treedef = jax.tree_flatten(out_stacked)
    out_unstacked_transposed = jax.tree_map(list, out_stacked)
    out_unstacked = jax.tree_transpose(
        outer_treedef, inner_treedef, out_unstacked_transposed
    )
    return out_unstacked

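# tree_vmap above relies on jax.tree_transpose to swap the outer and inner
# pytree structures of its result. A minimal sketch of that call's contract,
# mirroring the JAX documentation example with hypothetical data:
def _tree_transpose_example():
    import jax

    outer_treedef = jax.tree_structure([0, 0])             # list of length 2
    inner_treedef = jax.tree_structure({"a": 0, "b": 0})   # dict with keys a, b
    list_of_dicts = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
    dict_of_lists = jax.tree_transpose(outer_treedef, inner_treedef,
                                       list_of_dicts)
    assert dict_of_lists == {"a": [1, 3], "b": [2, 4]}
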
def check_eq(xs, ys, atol=None, rtol=None):
    xs_leaves, xs_tree = jax.tree_flatten(xs)
    ys_leaves, ys_tree = jax.tree_flatten(ys)
    assert xs_tree == ys_tree, "Tree shapes don't match."
    assert jax.tree_util.tree_all(
        jax.tree_multimap(lambda x, y: np.array(x).shape == np.array(y).shape,
                          xs_leaves, ys_leaves)), "Leaves' shapes don't match."
    assert jax.tree_multimap(
        partial(_assert_numpy_allclose, atol=atol, rtol=rtol),
        xs_leaves, ys_leaves)

def __call__(self, params: Params) -> float:
    """Evaluates the regularizer."""
    params = self._preprocess_fn(params)
    leaves, _ = jax.tree_flatten(params)
    if self._param_weights:
        param_weight_leaves, _ = jax.tree_flatten(self._param_weights)
        return sum(
            jnp.vdot(pw, jnp.square(x))
            for pw, x in zip(param_weight_leaves, leaves)) * self._weight
    return sum(jnp.vdot(x, x) for x in leaves) * self._weight

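# The regularizer above computes weight * sum_i <pw_i, x_i**2> when
# per-parameter weights are given, and weight * sum_i <x_i, x_i> otherwise,
# i.e. a (weighted) squared L2 norm over all leaves. A minimal standalone
# sketch of the unweighted branch with hypothetical values:
def _l2_regularizer_example():
    import jax
    import jax.numpy as jnp

    params = {"kernel": jnp.array([1.0, 2.0]), "bias": jnp.array([3.0])}
    weight = 0.5
    leaves, _ = jax.tree_flatten(params)
    value = sum(jnp.vdot(x, x) for x in leaves) * weight
    assert jnp.isclose(value, 7.0)  # (9 + 1 + 4) * 0.5
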
def test_load_state(self):
    temp_dir = self.get_temp_dir()
    path = os.path.join(temp_dir, 'state')
    init_state = test_util.create_mock_state()
    serialization.save_state(init_state, path)
    state = serialization.load_state(path)
    expected_flat, expected_tree_def = jax.tree_flatten(init_state)
    actual_flat, actual_tree_def = jax.tree_flatten(state)
    for expected_array, actual_array in zip(expected_flat, actual_flat):
        self.assertAllEqual(expected_array, actual_array)
    self.assertEqual(expected_tree_def, actual_tree_def)

def assertAgentParametersEqual(self, agent1: sac_agent.SACAgent,
                               agent2: sac_agent.SACAgent):
    agent1_params = get_agent_params(agent1)
    agent2_params = get_agent_params(agent2)
    agent1_params, agent1_structure = jax.tree_flatten(agent1_params)
    agent2_params, agent2_structure = jax.tree_flatten(agent2_params)
    self.assertEqual(agent1_structure, agent2_structure,
                     'Parameter structures do not match.')
    for param1, param2 in zip(agent1_params, agent2_params):
        if (param1 != param2).any():
            self.fail(f'Parameters are not equal: {param1}, {param2}')

def test_apply_if_finite(self, opt_builder):
    one = jnp.ones([])
    nan = jnp.array(jnp.nan)

    def fn(x):
        return x * hk.get_parameter('p', [], init=hk.initializers.Constant(0.))

    fn = hk.without_apply_rng(hk.transform(fn))
    params = fn.init(jax.random.PRNGKey(1905), one)
    opt = wrappers.apply_if_finite(opt_builder(), 2)
    state = opt.init(params)
    grads_fn = jax.grad(self.variant(fn.apply))

    # Do one successful param update
    grads = grads_fn(params, one)
    updates, state = opt.update(grads, state, params)
    params = update.apply_updates(params, updates)
    # We know exactly what should be the value of params since we are
    # effectively using sgd in all cases.
    self.assertEqual(-1., float(jax.tree_flatten(params)[0][0]))
    self.assertTrue(bool(state.last_finite))

    # Check 2 rejected param updates
    for step in range(2):
        grads = grads_fn(params, nan)
        updates, state = opt.update(grads, state, params)
        params = update.apply_updates(params, updates)
        self.assertEqual(-1., float(jax.tree_flatten(params)[0][0]))
        self.assertFalse(bool(state.last_finite))
        self.assertEqual(step + 1, int(state.notfinite_count))

    # Next successful param update
    grads = grads_fn(params, one)
    updates, state = opt.update(grads, state, params)
    params = update.apply_updates(params, updates)
    self.assertEqual(-2., float(jax.tree_flatten(params)[0][0]))
    self.assertTrue(bool(state.last_finite))

    # Again 2 rejected param updates
    for step in range(2):
        grads = grads_fn(params, nan)
        updates, state = opt.update(grads, state, params)
        params = update.apply_updates(params, updates)
        self.assertEqual(-2., float(jax.tree_flatten(params)[0][0]))
        self.assertFalse(bool(state.last_finite))
        self.assertEqual(step + 1, int(state.notfinite_count))

    # Next param update with NaN is accepted since we reached maximum
    grads = grads_fn(params, nan)
    updates, state = opt.update(grads, state, params)
    params = update.apply_updates(params, updates)
    self.assertTrue(bool(jnp.isnan(jax.tree_flatten(params)[0][0])))
    self.assertEqual(5, int(state.total_notfinite))

def train_benchmark(state):
    train_iter = iter(train_data)
    example = next(train_iter)
    params_, net_state_, opt_state_, metrics_state_, *_ = train_step(
        params, net_state, next(rng), opt_state, metrics_state, *example)
    [x.block_until_ready() for x in jax.tree_flatten(params_)[0]]
    while state:
        params_, net_state_, opt_state_, metrics_state_, *_ = train_step(
            params_, net_state_, next(rng), opt_state_, metrics_state_, *example)
        example = next(train_iter)
    [x.block_until_ready() for x in jax.tree_flatten(params_)[0]]

def backward(ctx, *tangents):
    tangents = jax.tree_map(map_if(is_torch_tensor)(to_jax_ndarray), tangents)
    tangents = jax.tree_unflatten(ctx.result_tree, tangents[:-1])
    grads = ctx.fun_vjp(tangents)
    return (None, *jax.tree_flatten(
        jax.tree_map(map_if(is_jax_ndarray)(to_torch_tensor), grads))[0])

def forward(ctx, fn, *args):
    args = jax.tree_map(map_if(is_torch_tensor)(to_jax_ndarray), args)
    result, ctx.fun_vjp = jax.vjp(fn, *args)
    result_flat, result_tree = jax.tree_flatten(result)
    ctx.result_tree = result_tree
    return (*jax.tree_map(
        map_if(is_jax_ndarray)(to_torch_tensor), result_flat), result_tree)

def update_fn(grads, state, params):
    """Transform the input gradient and update all statistics.

    Args:
      grads: the gradient tensors for the parameters.
      state: a named tuple containing the state of the optimizer.
      params: the parameters that should be updated.

    Returns:
      A tuple containing the transformed updates and the new optimizer state.
    """
    params_flat, treedef = jax.tree_flatten(params)
    stats_flat = treedef.flatten_up_to(state.stats)
    grads_flat = treedef.flatten_up_to(grads)
    new_stats_flat = jax.tree_multimap(
        lambda g, s, p: _compute_stats(g, s, p, state.count),
        grads_flat, stats_flat, params_flat)
    new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat,
                                              state.count)
    outputs = jax.tree_multimap(
        lambda g, s, p: _transform_grad(g, s, p, state.count),
        grads_flat, new_stats_flat, params_flat)
    updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
    updates = jax.tree_unflatten(treedef, updates_flat)
    new_stats = jax.tree_unflatten(treedef, new_stats_flat)
    new_state = ShampooState(count=state.count + 1, stats=new_stats)
    return updates, new_state

def test_simple_mask_two_layer(self):
    """Tests generation of a simple mask."""
    mask = {
        'MaskedModule_0': {
            'kernel': jnp.zeros(
                self._masked_model_twolayer.params['MaskedModule_0']['unmasked']
                ['kernel'].shape),
            'bias': None,
        },
        'MaskedModule_1': {
            'kernel': jnp.zeros(
                self._masked_model_twolayer.params['MaskedModule_1']['unmasked']
                ['kernel'].shape),
            'bias': None,
        },
    }
    gen_mask = masked.simple_mask(self._masked_model_twolayer, jnp.zeros,
                                  ['kernel'])
    result, _ = jax.tree_flatten(
        jax.tree_util.tree_multimap(lambda x, *xs: (x == xs[0]).all(), mask,
                                    gen_mask))
    self.assertTrue(all(result))

def apply_gradient(self, hyper_params, params, state, grads):
    p_leaves, treedef = jax.tree_flatten(params)
    s_leaves = treedef.flatten_up_to(state.param_states)
    g_leaves = treedef.flatten_up_to(grads)
    split_grads = zip(*(self._split_grad(p, s, g, hyper_params.wn_decay)
                        for p, s, g in zip(p_leaves, s_leaves, g_leaves)))
    d_p, d_s, d_g, s_p, s_s, s_g = [
        jax.tree_unflatten(treedef, x) for x in split_grads
    ]
    wn_params = {'direction': d_p, 'scale': s_p}
    wn_state = {'direction': d_s, 'scale': s_s}
    wn_grads = {'direction': d_g, 'scale': s_g}
    new_wn_params, new_state = self.wrapped_optimizer.apply_gradient(
        hyper_params.inner, wn_params,
        state.replace(param_states=wn_state), wn_grads)

    directions = treedef.flatten_up_to(new_wn_params['direction'])
    scales = treedef.flatten_up_to(new_wn_params['scale'])
    new_params, mults = zip(*(self._merge_param(d, s, hyper_params.wn_eps)
                              for d, s in zip(directions, scales)))
    new_params = jax.tree_unflatten(treedef, new_params)
    mults = jax.tree_unflatten(treedef, mults)

    direction_state = new_state.param_states['direction']
    scale_state = new_state.param_states['scale']
    param_states = jax.tree_multimap(
        lambda _, *args: _WeightNormParamState(*args),
        params, direction_state, scale_state, mults)
    return new_params, new_state.replace(param_states=param_states)

def scan_wrapper(f, init, xs, length, reverse, rng_key=None,
                 substitute_stack=[], enum=False):
    if length is None:
        length = tree_flatten(xs)[0][0].shape[0]

    if enum:
        return scan_enum(f, init, xs, length, reverse, rng_key, substitute_stack)

    def body_fn(wrapped_carry, x):
        i, rng_key, carry = wrapped_carry
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (
            None, None)

        with handlers.block():
            seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
            for subs_type, subs_map in substitute_stack:
                subs_fn = partial(_subs_wrapper, subs_map, i, length)
                if subs_type == 'condition':
                    seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
                elif subs_type == 'substitute':
                    seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

            with handlers.trace() as trace:
                carry, y = seeded_fn(carry, x)

        return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

    return lax.scan(body_fn, (jnp.array(0), rng_key, init), xs, length=length,
                    reverse=reverse)

def body_fn(wrapped_carry, x, prefix=None):
    i, rng_key, carry = wrapped_carry
    init = True if (not_jax_tracer(i) and i == 0) else False
    rng_key, subkey = random.split(rng_key) if rng_key is not None else (
        None, None)

    seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
    for subs_type, subs_map in substitute_stack:
        subs_fn = partial(_subs_wrapper, subs_map, i, length)
        if subs_type == 'condition':
            seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
        elif subs_type == 'substitute':
            seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

    if init:
        with handlers.scope(prefix="_init"):
            new_carry, y = seeded_fn(carry, x)
            trace = {}
    else:
        with handlers.block(), packed_trace() as trace, promote_shapes(), \
                enum(), markov():
            # Like scan_wrapper, we collect the trace of scan's transition
            # function `seeded_fn` here. To put the time dimension in the
            # correct position, we need to promote shapes so that `fn` and
            # `value` at each site have the same batch dims (e.g. if
            # `fn.batch_shape = (2, 3)` and value's batch_shape is (3,), we
            # promote the shape of value so that its batch shape is (1, 3)).
            new_carry, y = config_enumerate(seeded_fn)(carry, x)

        # store shape of new_carry at a global variable
        nonlocal carry_shape_at_t1
        carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
        # make new_carry have the same shape as carry
        # FIXME: is this rigorous?
        new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                  new_carry, carry)
    return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)

def apply_gradient(self, hyper_params, params, state, grads, hessians):
    """Applies a gradient for a set of parameters.

    Args:
      hyper_params: a named tuple of hyper parameters.
      params: the parameters that should be updated.
      state: a named tuple containing the state of the optimizer.
      grads: the gradient tensors for the parameters.
      hessians: the hessian tensors for the parameters.

    Returns:
      A tuple containing the new parameters and the new optimizer state.
    """
    step = state.step
    params_flat, treedef = jax.tree_flatten(params)
    states_flat = treedef.flatten_up_to(state.param_states)
    grads_flat = treedef.flatten_up_to(grads)
    hessians_flat = treedef.flatten_up_to(hessians)
    out = [
        self.apply_param_gradient(step, hyper_params, param, state, grad,
                                  hessian)
        for param, state, grad, hessian in zip(params_flat, states_flat,
                                               grads_flat, hessians_flat)
    ]
    new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ())
    new_params = jax.tree_unflatten(treedef, new_params_flat)
    new_param_states = jax.tree_unflatten(treedef, new_states_flat)
    new_state = OptimizerState(step + 1, new_param_states)
    return new_params, new_state

def forward_compute_losses(
    params_primals: Any,
) -> Sequence[Sequence[jnp.ndarray]]:
    primals_[params_index] = params_primals
    flat_args = jax.tree_flatten(primals_)[0]

    # Mapping from variable -> value
    env = dict()
    read = functools.partial(tgm.read_env, env)
    write = functools.partial(tgm.write_env, env)

    # Bind args and consts to environment
    write(jax.core.unitvar, jax.core.unit)
    jax_util.safe_map(write, jaxpr.invars, flat_args)
    jax_util.safe_map(write, jaxpr.constvars, consts)

    # Loop through equations and evaluate primitives using `bind`
    losses_so_far = 0
    loss_tags = []
    for eqn in jaxpr.eqns:
        tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
        if isinstance(eqn.primitive, tags.LossTag):
            loss_tags.append(eqn)
            losses_so_far += 1
        if num_losses is not None and losses_so_far == num_losses:
            break
    return tuple(tuple(read(v) for v in tag.invars) for tag in loss_tags)

def vjp_func(tangents):
    all_tangents = aux_vjp(tangents)
    tangents_dict, inputs_tangents = all_tangents[0], all_tangents[1:]
    inputs_tangents = jax.tree_flatten(inputs_tangents)[0]
    tangents_dict.update(zip(jaxpr.invars, inputs_tangents))

    read_primals = functools.partial(tgm.read_env, primals_dict)
    read_tangents = functools.partial(tgm.read_env, tangents_dict)

    layers_info = []
    for jaxpr_eqn in layer_tags:
        layer_tag = _unbox_layer_tag(jaxpr_eqn)
        info = dict()

        primals = jax_util.safe_map(read_primals, tuple(jaxpr_eqn.invars))
        (
            info["outputs"],
            info["inputs"],
            info["params"],
        ) = layer_tag.split_all_inputs(primals)

        tangents = jax_util.safe_map(read_tangents, tuple(jaxpr_eqn.invars))
        (
            info["outputs_tangent"],
            info["inputs_tangent"],
            info["params_tangent"],
        ) = layer_tag.split_all_inputs(tangents)

        layers_info.append(info)
    return tuple(layers_info)