  def test_meta_loss(self, model_name):
    """Test that meta_init does not update the bias scalars."""
    rng = jax.random.PRNGKey(0)
    flax_module, params, input_shape, _ = _load_model(model_name)
    norms = jax.tree_map(lambda node: jnp.linalg.norm(node.reshape(-1)),
                         params)
    normalized_params = jax.tree_map(meta_init.normalize, params)
    loss_name = 'cross_entropy'
    loss_fn = losses.get_loss_fn(loss_name)
    learned_norms, _ = meta_init.meta_optimize_scales(
        loss_fn=loss_fn,
        fprop=flax_module.apply,
        normalized_params=normalized_params,
        norms=norms,
        hps=meta_init.DEFAULT_HPARAMS,
        input_shape=input_shape[1:],
        output_shape=OUTPUT_SHAPE,
        rng_key=rng)

    # Check that all learned bias scales are 0; the meta loss should be
    # independent of these terms.
    learned_norms_flat = model_utils.flatten_dict(learned_norms)
    for layer_key in learned_norms_flat:
      if 'bias' in layer_key:
        self.assertEqual(learned_norms_flat[layer_key], 0.0)
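
# A minimal sketch of the normalization contract the test above relies on,
# assuming meta_init.normalize rescales each tensor to unit L2 norm while
# leaving all-zero tensors (such as freshly initialized biases) at zero, so
# their learned scales receive no gradient. The function name is
# hypothetical; this illustrates the assumed behavior, not meta_init's
# actual implementation.
def _normalize_sketch(node):
  norm = jnp.linalg.norm(node.reshape(-1))
  safe_norm = jnp.where(norm > 0, norm, 1.0)  # Avoid dividing zero by zero.
  return node / safe_norm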
def get_summary_tree(training_metrics_grabber):
  """Extracts desired training statistics from the grabber state.

  Currently this function will compute the scalar aggregate gradient variance
  for every weight matrix of the model. Future iterations of this function
  may depend on the metrics_grabber config.

  Args:
    training_metrics_grabber: TrainingMetricsGrabber object.

  Returns:
    A dict of different aggregate training statistics.
  """
  unreplicated_metrics_tree = jax.tree_map(lambda x: x[0],
                                           training_metrics_grabber.state)
  # Example key: Layer1/conv1/kernel/
  # NOTE: jax.tree_map does not work here, because tree_map will additionally
  # flatten the node state, while model_utils.flatten_dict will consider the
  # node object a leaf.
  flat_metrics = model_utils.flatten_dict(unreplicated_metrics_tree)

  # Grab just the gradient_variance terms.
  def _reduce_node(node):
    # Var[g] = E[g^2] - E[g]^2
    grad_var_ema = node.grad_sq_ema - jnp.square(node.grad_ema)
    update_var_ema = node.update_sq_ema - jnp.square(node.update_ema)
    return {
        'grad_var': grad_var_ema.sum(),
        'param_norm': node.param_norm,
        'update_var': update_var_ema.sum(),
        'update_ratio': update_var_ema.sum() / node.param_norm,
    }

  return {k: _reduce_node(flat_metrics[k]) for k in flat_metrics}
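
# For reference, a hypothetical sketch of the EMA updates that would produce
# the grad_ema / grad_sq_ema fields consumed by _reduce_node above, using the
# standard ema <- decay * ema + (1 - decay) * value rule. The function name
# and decay value are assumptions for illustration, not the grabber's actual
# update step.
def _ema_update_sketch(grad_ema, grad_sq_ema, grad, decay=0.999):
  new_grad_ema = decay * grad_ema + (1 - decay) * grad
  new_grad_sq_ema = decay * grad_sq_ema + (1 - decay) * jnp.square(grad)
  return new_grad_ema, new_grad_sq_ema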
def _get_non_bias_params(params):
  flat_params = model_utils.flatten_dict(params)
  # Bias and other scalar/vector parameters have rank < 2, so keeping only
  # keys with rank >= 2 selects the non-bias weight matrices.
  non_bias_keys = [
      key for key in flat_params if len(flat_params[key].shape) >= 2
  ]
  return non_bias_keys
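
def _non_bias_params_example():
  """Hypothetical usage of _get_non_bias_params on a toy parameter tree.

  Assumes model_utils.flatten_dict joins nested keys into path-like strings
  (e.g. 'Dense_0/kernel'); the layer name and shapes are illustrative only.
  """
  params = {'Dense_0': {'kernel': jnp.ones((8, 4)), 'bias': jnp.zeros((4,))}}
  # Expected: ['Dense_0/kernel']; the rank-1 bias is excluded.
  return _get_non_bias_params(params)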