def logprior(params):
  # Spherical Gaussian prior.
  leaves_of_params = tree_leaves(params)
  return sum(
      tree_map(
          lambda p: jnp.sum(
              jax.scipy.stats.norm.logpdf(p, scale=l2_regularizer)),
          leaves_of_params))
def kl_fn(params):
  n_params = sum([p.size for p in jax.tree_leaves(params)])
  sigma_prior = jnp.sqrt(1 / weight_decay)
  mu_vi_tree = params["mean"]
  sigma_vi_tree = jax.tree_map(jax.nn.softplus, params["inv_softplus_std"])

  def get_parameter_kl(mu_vi, sigma_vi):
    return (jnp.log(sigma_prior / sigma_vi) +
            (sigma_vi**2 + mu_vi**2) / 2 / sigma_prior**2 - 1 / 2)

  kl_tree = jax.tree_multimap(get_parameter_kl, mu_vi_tree, sigma_vi_tree)
  kl = sum([p_kl.sum() for p_kl in jax.tree_leaves(kl_tree)])
  return -kl * temperature
def _measure_and_maybe_clip_grad(grad, metrics, clipped_grad_norm=None):
  """Records and optionally clips gradient."""
  grad_l2_sum = sum([jnp.sum(x**2) for x in jax.tree_leaves(grad)])
  metrics["unclipped_grad_l2_sum"] = grad_l2_sum
  if clipped_grad_norm is not None:
    # Clip gradients after pmean aggregation.
    grad = jax.experimental.optimizers.clip_grads(grad, clipped_grad_norm)
    metrics["clipped_grad_l2_sum"] = sum(
        [jnp.sum(x**2) for x in jax.tree_leaves(grad)])
  else:
    # Clipped grad is the same as the unclipped grad.
    metrics["clipped_grad_l2_sum"] = grad_l2_sum
  return grad, metrics
def tree_count_infs_nans(tree, psum_axis_name=None):
  leaves = jax.tree_leaves(tree)
  num_infs = sum(jnp.sum(jnp.isinf(x)) for x in leaves)
  num_nans = sum(jnp.sum(jnp.isnan(x)) for x in leaves)
  if psum_axis_name:
    num_infs, num_nans = jax.lax.psum((num_infs, num_nans),
                                      axis_name=psum_axis_name)
  return num_infs, num_nans
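# Hypothetical usage sketch for `tree_count_infs_nans` above (the example tree
# and values are illustrative, not from the original source); run outside
# pmap, so `psum_axis_name` is left at its default.
import jax.numpy as jnp

example_tree = {"w": jnp.array([1.0, jnp.inf]), "b": jnp.array([jnp.nan])}
num_infs, num_nans = tree_count_infs_nans(example_tree)
# num_infs == 1 and num_nans == 1 for this tree.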
def maybe_compute_and_add_sql2_to_metrics_dict(variable_tree, norm_tree_sql2,
                                               key):
  if variable_tree and not norm_tree_sql2:
    norm_tree_sql2 = self._tree_norm_fn_sql2(variable_tree)
  if norm_tree_sql2:
    metrics_dict['{}_norms_sql2'.format(key)] = norm_tree_sql2
    metrics_dict['global_{}_norm_sql2'.format(key)] = sum(
        jax.tree_leaves(norm_tree_sql2))
def lazy_fn(*inputs):
  nonlocal output_traced, master
  leaves = jax.tree_leaves(inputs)
  if leaves:
    master = leaves[0]._trace.master  # pylint: disable=protected-access
  output = fn(*(inputs + args), **kwargs)
  output_traced = output
  return output
def log_prior(params):
  """Computes the Gaussian prior log-density."""
  # ToDo izmailovpavel: make temperature treatment the same as in gaussian
  # likelihood function.
  n_params = sum([p.size for p in jax.tree_leaves(params)])
  log_prob = -(0.5 * tree_utils.tree_dot(params, params) * weight_decay +
               0.5 * n_params * jnp.log(weight_decay / (2 * math.pi)))
  return log_prob / temperature
def _test_model_params(self, model_name: str, image_size: int,
                       expected_params: int):
  module = efficientnet.get_efficientnet_module(model_name)
  model, _ = load_model.create_image_model(jax.random.PRNGKey(0),
                                           batch_size=1,
                                           image_size=image_size,
                                           module=module)
  num_params = sum(np.prod(e.shape) for e in jax.tree_leaves(model))
  self.assertEqual(num_params, expected_params)
def test_apply_updates_mixed_precision(self):
  params = ({
      'a': jnp.ones((3, 2), dtype=jnp.bfloat16)
  }, jnp.ones((1,), dtype=jnp.bfloat16))
  grads = jax.tree_map(lambda t: (2 * t).astype(jnp.float32), params)
  new_params = self.variant(update.apply_updates)(params, grads)
  for leaf in jax.tree_leaves(new_params):
    assert leaf.dtype == jnp.bfloat16
def wrapped(x, *args, **kwargs):
  for ys in jax.tree_leaves(args):
    assert ys.shape[0] == num_layers
  return _LayerStackWithPerLayer(
      f,
      num_layers,
      unroll=unroll,
      pass_reverse_to_layer_fn=pass_reverse_to_layer_fn,
      name=name)(x, *args, **kwargs)
def tree_leaf_iscomplex(pars):
  """Returns true if at least one leaf in the tree has complex dtype."""

  def _has_complex_dtype(x):
    # Returns true if x is complex.
    return jnp.issubdtype(x.dtype, jnp.complexfloating)

  return any(jax.tree_leaves(jax.tree_map(_has_complex_dtype, pars)))
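# Minimal sketch checking `tree_leaf_iscomplex` above on toy pytrees; the
# inputs are hypothetical and assume `jax.numpy` is available as `jnp`.
import jax.numpy as jnp

mixed = {"real": jnp.ones(3), "cplx": jnp.ones(3, dtype=jnp.complex64)}
assert tree_leaf_iscomplex(mixed)            # one complex leaf -> True
assert not tree_leaf_iscomplex(jnp.ones(3))  # purely real leaves -> False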
def maybe_get_axis(axis: int, arrays: Any) -> Optional[int]:
  """Returns `array.shape[axis]` for one of the arrays in the input."""
  shapes = [a.shape for a in jax.tree_leaves(arrays)]
  sizes = {s[axis] for s in shapes}
  if len(sizes) != 1:
    raise ValueError("Arrays must have the same mapped axis size, found "
                     f"sizes {sizes} for input shapes {shapes}")
  size, = sizes
  return size
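# Minimal usage sketch for `maybe_get_axis` above, assuming all leaves share
# the mapped axis size; the batch dict below is illustrative only.
import jax.numpy as jnp

batch = {"x": jnp.zeros((8, 3)), "y": jnp.zeros((8,))}
assert maybe_get_axis(0, batch) == 8  # every leaf has size 8 along axis 0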
def init_fn(params):
  count = jnp.zeros([], jnp.int32)
  dtype = getattr(next(iter(jax.tree_leaves(params)), None), 'dtype', None)
  hparams = {
      k: jnp.asarray(_convert_floats(v, dtype))
      for k, v in numeric_hps.items()}
  hparams.update(schedule_fn(count, dtype))
  return InjectHyperparamsState(  # pylint:disable=too-many-function-args
      count, hparams, inner_factory(**other_hps, **hparams).init(params))
def static_unroll(core, inputs, state):
  """Unroll core over inputs, starting from state."""
  outs = []
  num_steps = jax.tree_leaves(inputs)[0].shape[0]
  for t in range(num_steps):
    next_input = jax.tree_map(lambda x, t=t: x[t], inputs)
    out, state = core(next_input, state)
    outs.append(out)
  return jnp.stack(outs), state
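# Sketch of how `static_unroll` above might be called with a toy stateless
# core; the core and inputs are hypothetical, not from the original source.
import jax.numpy as jnp

def toy_core(x, state):
  new_state = state + x        # accumulate the input into the carried state
  return new_state, new_state  # emit the running sum as the step output

inputs = jnp.arange(4.0)  # four time steps along the leading axis
outs, final_state = static_unroll(toy_core, inputs, jnp.zeros(()))
# outs has shape (4,) and final_state == 6.0 for this toy example.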
def _level_of_value(xs):
  """Returns the tracer level associated with a value if any."""
  xs = jax.tree_leaves(xs)
  max_level = float('-inf')
  for x in xs:
    if hasattr(x, '_trace'):
      level = _trace_level(x._trace.master)
      max_level = max(level, max_level)
  return max_level
def __call__(self, inputs, state):
  inputs, should_reset = inputs
  # Protect against an easy, invisible error class. This should be jitted out.
  # >>> np.where(np.asarray([False, True]), np.zeros([2,2]), np.ones([2,2]))
  # ... array([[1., 0.], [1., 0.]])
  # Using a should_reset of rank R - 1 could result in one example
  # affecting another.
  for x in jax.tree_leaves(state):
    if len(x.shape) - 1 != len(should_reset.shape):
      raise ValueError("should_reset must have rank-1 of state.")
  should_reset = jnp.expand_dims(should_reset, axis=-1)

  batch_size = jax.tree_leaves(inputs)[0].shape[0]
  initial_state = jax.tree_map(lambda v: v.astype(inputs.dtype),
                               self.initial_state(batch_size))
  state = jax.tree_multimap(lambda i, s: jnp.where(should_reset, i, s),
                            initial_state, state)
  return self._core(inputs, state)
def scan(f, init, xs, length=None, reverse=False):
  """Equivalent to `jax.lax.scan` but with Haiku state threaded in and out."""
  if not base.inside_transform():
    raise ValueError("hk.scan() should not be used outside of hk.transform(). "
                     "Use jax.lax.scan() instead.")

  if length is None:
    length = jax.tree_leaves(xs)[0].shape[0]

  running_init_fn = not base.params_frozen()

  if running_init_fn:
    # During `init` we need to unroll one step of the scan, this is because
    # our carry contains the Haiku state and during `init` this may change
    # structure (e.g. as state is created).
    if not length:
      x0 = jax.tree_map(lambda x: jnp.zeros(x.shape[1:], x.dtype), xs)
      _, y0 = f(init, x0)
      y0 = jax.tree_map(lambda y: jnp.zeros((0,) + y.shape, y.dtype), y0)
      return init, y0

    if reverse:
      x0 = jax.tree_map(lambda x: x[-1], xs)
      xs = jax.tree_map(lambda x: x[:-1], xs)
    else:
      x0 = jax.tree_map(lambda x: x[0], xs)
      xs = jax.tree_map(lambda x: x[1:], xs)
    init, y0 = f(init, x0)
    y0 = jax.tree_map(lambda y: jnp.expand_dims(y, 0), y0)
    length -= 1
    if not length:
      return init, y0

  def stateful_fun(carry, x):
    carry, state = carry
    with temporary_internal_state(state):
      with base.assert_no_new_parameters():
        carry, out = f(carry, x)
      carry = (carry, internal_state(params=False))
      return carry, out

  # We know that we don't need to thread params in and out, since for init we
  # have already created them (given that above we unroll one step of the
  # scan) and for apply we know they are immutable. As such we only need to
  # thread the state and rng in and out.
  init = (init, internal_state(params=False))
  (carry, state), ys = jax.lax.scan(stateful_fun, init, xs, length, reverse)
  update_internal_state(state)

  if running_init_fn:
    ys = jax.tree_multimap(lambda y0, ys: jnp.concatenate([y0, ys]), y0, ys)

  return carry, ys
def wrapper(*args: pytypes.ArrayTree, **kwargs: pytypes.ArrayTree):
  if kwargs and (in_axes != 0 or static_argnums):
    raise ValueError("Do not use kwargs with `in_axes` or `static_argnums` "
                     "in pmapped function.")

  devices_ = list(devices or jax.devices(backend))
  n_devices_ = n_devices or len(devices_)
  devices_ = devices_[:n_devices_]
  if len(devices_) != n_devices_:
    raise ValueError("Number of available devices is less than required for "
                     f"test ({len(devices_)} < {n_devices_})")

  bcast_fn = lambda x: jnp.broadcast_to(x, (n_devices_,) + jnp.array(x).shape)
  if broadcast_args_to_devices:
    args = [
        tree_map(bcast_fn, arg) if idx not in static_argnums else arg
        for idx, arg in enumerate(args)
    ]
    kwargs = tree_map(bcast_fn, kwargs)
  else:
    # Pmappable axes size must be equal to number of devices.
    in_axes_ = in_axes if isinstance(in_axes,
                                     (tuple, list)) else [in_axes] * len(args)
    is_pmappable_arg = [
        idx not in static_argnums and in_axes_[idx] is not None
        for idx in range(len(args))
    ]
    for is_pmappable, arg in zip(is_pmappable_arg, args):
      if not is_pmappable:
        continue
      if not all(x.shape[0] == n_devices_ for x in jax.tree_leaves(arg)):
        shapes = tree_map(jnp.shape, arg)
        raise ValueError(
            "Pmappable arg axes size must be equal to number of devices, "
            f"got: {shapes} (expected the first dim to be {n_devices_}). "
            "Consider setting `broadcast_args_to_devices=True`.")

  new_kwargs = dict(axis_name=axis_name,
                    devices=devices_,
                    in_axes=in_axes,
                    static_broadcasted_argnums=static_argnums,
                    backend=backend)

  # Re-compile fn if kwargs changed.
  nonlocal pmap_kwargs
  nonlocal pmapped_fn
  if new_kwargs != pmap_kwargs:
    pmap_kwargs = new_kwargs
    pmapped_fn = jax.pmap(fn, **pmap_kwargs)

  res = pmapped_fn(*args, **kwargs)
  return reduce_fn(res)
def _level_of_value(xs):
  """Returns the tracer level associated with a value if any."""
  xs = jax.tree_leaves(xs)
  max_level = float('-inf')
  # TODO(jheek): consider re-introducing the tracer check
  # for x in xs:
  #   if hasattr(x, '_trace'):
  #     level = _trace_level(x._trace.master)
  #     max_level = max(level, max_level)
  return max_level
def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
  """Creates a parameter if it doesn't exist yet in this scope and returns it.

  Args:
    name: the name of the parameter.
    init_fn: a function taking a PRNGKey plus any other number of positional
      arguments.
    *init_args: the arguments to evaluate init_fn on lazily.

  Returns:
    The parameter.
  """
  self.reserve(name)
  if self.has_variable('params', name):
    abs_rng = jax.ShapeDtypeStruct((2,), jnp.uint32)
    value = self.get_variable('params', name)
    # Validate that the shape of the init_fn output is the same as the shape
    # of the existing parameter. This is to make sure that the hparams set up
    # in a Flax Module match the shapes coming in during apply, and if not,
    # catch it with an error message.
    # NOTE: We could consider moving this to `self.`
    abs_value = jax.eval_shape(lambda rng: init_fn(rng, *init_args), abs_rng)
    abs_value_flat = jax.tree_leaves(abs_value)
    value_flat = jax.tree_leaves(value)
    for val, abs_val in zip(value_flat, abs_value_flat):
      # NOTE: We could check dtype consistency here as well, but its
      # usefulness is less obvious. We might intentionally change the dtype
      # for inference to a half float type for example.
      if jnp.shape(val) != jnp.shape(abs_val):
        raise ValueError(
            'Inconsistent shapes between value and initializer '
            f'for parameter "{name}" in "{self.path_text}": '
            f'{jnp.shape(val)}, {jnp.shape(abs_val)}')
    return value
  else:
    if not self.is_mutable_collection('params'):
      raise ValueError(
          f'No parameter named "{name}" exists in "{self.path_text}".')
    value = init_fn(self.make_rng('params'), *init_args)
    self.put_variable('params', name, value)
    return value
def eval_step(model, state, batch, num_classes, flatten_input=True):
  eval_keys = ['inputs', 'targets']
  (inputs, targets) = [batch.get(k, None) for k in eval_keys]
  if flatten_input:
    inputs = inputs.reshape(inputs.shape[0], -1)
  if jax.tree_leaves(state):
    state = jax.lax.pmean(state, 'batch')
  with nn.stateful(state, mutable=False):
    logits = model(inputs, train=False)
  return compute_metrics(logits, targets, num_classes, weights=None)
def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
  """Compute the loss of the network, including L2."""
  logits = net.apply(params, batch)
  labels = jax.nn.one_hot(batch["label"], 10)

  l2_loss = 0.5 * sum(
      jnp.sum(jnp.square(p)) for p in jax.tree_leaves(params))
  softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
  softmax_xent /= labels.shape[0]

  return softmax_xent + 1e-4 * l2_loss
def adaptive_hmc_update(state,
                        log_prob,
                        state_grad,
                        key,
                        step_size,
                        trajectory_len,
                        target_accept_rate=0.8,
                        step_size_adaptation_speed=0.05,
                        max_n_leapfrog=1000,
                        jitter_amt=0.2):
  normal_key, uniform_key, jitter_key = jax.random.split(key, 3)
  n_leapfrog = jnp.array(jnp.ceil(trajectory_len / step_size), jnp.int32)
  n_leapfrog = jnp.minimum(n_leapfrog, max_n_leapfrog)

  jittered_step_size = step_size * jnp.exp(
      jnp.where(
          jnp.logical_or(step_size_adaptation_speed <= 0,
                         target_accept_rate <= 0),
          jnp.log(1. + jitter_amt) *
          (2 * jax.random.uniform(jitter_key, ()) - 1.), 0.))

  num_leaves = len(jax.tree_leaves(state))
  normal_keys = list(jax.random.split(normal_key, num_leaves))
  treedef = jax.tree_structure(state)
  normal_keys = jax.tree_unflatten(treedef, normal_keys)
  momentum = jax.tree_multimap(lambda s, key: jax.random.normal(key, s.shape),
                               state, normal_keys)

  initial_energy = get_kinetic_energy(momentum) - log_prob
  new_state, new_momentum, new_grad, new_log_prob = leapfrog(
      jittered_step_size, n_leapfrog, state, momentum, state_grad)
  new_energy = _nan_to_inf(get_kinetic_energy(new_momentum) - new_log_prob)

  energy_diff = initial_energy - new_energy
  accept_prob = jnp.minimum(1., jnp.exp(energy_diff))
  # TODO(izmailovpavel): check why the second condition is needed.
  accepted = jnp.logical_and(
      jax.random.uniform(uniform_key, log_prob.shape) < accept_prob,
      jnp.isfinite(energy_diff))

  step_size = step_size * jnp.exp(
      jnp.where(
          jnp.logical_or(target_accept_rate <= 0,
                         step_size_adaptation_speed <= 0), 0.,
          step_size_adaptation_speed *
          (jnp.mean(accept_prob) - target_accept_rate)))

  state = jax.lax.cond(accepted, _first, _second, (new_state, state))
  log_prob = jnp.where(accepted, new_log_prob, log_prob)
  state_grad = jax.lax.cond(accepted, _first, _second, (new_grad, state_grad))
  return state, log_prob, state_grad, step_size, accept_prob
def _aggregate_nodes_to_globals(graph, node_features):
  n_graph = graph.n_node.shape[0]
  sum_n_node = jax.tree_leaves(graph.nodes)[0].shape[0]
  graph_idx = jnp.arange(n_graph)
  node_gr_idx = jnp.repeat(graph_idx,
                           graph.n_node,
                           axis=0,
                           total_repeat_length=sum_n_node)
  return jax.ops.segment_sum(node_features, node_gr_idx, num_segments=n_graph)
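# Illustration of the segment-sum aggregation used by
# `_aggregate_nodes_to_globals` above, on toy data (two graphs with 2 and 3
# nodes); the arrays here are hypothetical, not from the original source.
import jax
import jax.numpy as jnp

n_node = jnp.array([2, 3])
node_features = jnp.ones((5, 4))  # 5 nodes total, 4 features each
node_gr_idx = jnp.repeat(jnp.arange(2), n_node, total_repeat_length=5)
per_graph = jax.ops.segment_sum(node_features, node_gr_idx, num_segments=2)
# per_graph has shape (2, 4): row 0 sums 2 nodes, row 1 sums 3 nodes.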
def update_fn(updates, state, params=None):
  count_inc = numerics.safe_int32_increment(state.count)
  dtype = getattr(next(iter(jax.tree_leaves(updates)), None), 'dtype', None)
  hparams = {k: _convert_floats(v, dtype)
             for k, v in state.hyperparams.items()}
  hparams.update(schedule_fn(count_inc, dtype))
  updates, inner_state = inner_factory(**other_hps, **hparams).update(
      updates, state.inner_state, params)  # pylint:disable=too-many-function-args
  return updates, InjectHyperparamsState(count_inc, hparams, inner_state)
def test_ParameterCount(self, model_name: str):
  # Parameter count from the autoaugment paper models, 100 classes:
  reference_parameter_count = {
      'WideResnet28x10': 36278324,
      'WideResnet28x6_ShakeShake': 26227572,
      'Pyramid_ShakeDrop': 26288692,
  }
  model, _ = load_model.get_model(model_name, 1, 32, 100)
  parameter_count = sum(np.prod(e.shape) for e in jax.tree_leaves(model))
  self.assertEqual(parameter_count, reference_parameter_count[model_name])
def inner(scope_fun, repack_fun, variable_groups, rng_groups, *args):
  nonlocal scope_fn, repack_fn
  try:
    scope_fn = scope_fun
    repack_fn = repack_fun
    scopes = jax.tree_leaves(scope_fn(variable_groups, rng_groups))
    mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes)
    return jitted(mutable, variable_groups, rng_groups, *args)
  finally:
    scope_fn, repack_fn = None, None
def _broadcast_global_to_nodes(
    global_feature: jnp.ndarray,
    graph: jraph.GraphsTuple,
) -> jnp.ndarray:
  graph_idx = jnp.arange(graph.n_node.shape[0])
  sum_n_node = jax.tree_leaves(graph.nodes)[0].shape[0]
  node_graph_idx = jnp.repeat(graph_idx,
                              graph.n_node,
                              axis=0,
                              total_repeat_length=sum_n_node)
  return global_feature[node_graph_idx]
def init(self, master_rng, data, params=None, network_state=None,
         replicated_params=False):
  """Initializes state of the updater."""
  data = self._preprocess(data)
  rngs = np.array([master_rng] * self._num_devices)
  if not replicated_params and params is not None:
    params = jax.tree_map(lambda x: np.array([x] * self._num_devices), params)
  state = self._init_fn(rngs, params, network_state, data)
  state['step'] = np.array(0, dtype=np.int64)
  # Wait for initialization to finish before starting training to keep
  # memory usage low.
  flat_params = jax.tree_leaves(state['params'])
  if flat_params:
    flat_params[0].block_until_ready()
  return state
def loss_fn(model):
  """loss function used for training."""
  with nn.stateful(state.model_state) as new_model_state:
    logits = model(batch['image'])
  loss = cross_entropy_loss(logits, batch['label'])
  weight_penalty_params = jax.tree_leaves(model.params)
  weight_decay = 0.0001
  weight_l2 = sum(
      [jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1])
  weight_penalty = weight_decay * 0.5 * weight_l2
  loss = loss + weight_penalty
  return loss * FLAGS.loss_scaling, (new_model_state, logits)