def dist_params_structure(self):
    r"""The tree structure of the distribution parameters."""
    return jax.tree_structure(self.default_priors)
def init_weights_and_state(self, input_signature):
    super().init_weights_and_state(input_signature)
    if self.init_checkpoint is None:
        return
    print('Loading pre-trained weights from', self.init_checkpoint)
    ckpt = tf.train.load_checkpoint(self.init_checkpoint)

    def reshape_qkv(name):
        x = ckpt.get_tensor(name)
        return x.reshape((x.shape[0], -1, 64)).swapaxes(0, 1)

    def reshape_o(name):
        x = ckpt.get_tensor(name)
        return x.reshape((-1, 64, x.shape[-1]))

    def reshape_bias(name):
        x = ckpt.get_tensor(name)
        return x.reshape((-1, 64))

    new_w = [
        ckpt.get_tensor('bert/embeddings/word_embeddings'),
        ckpt.get_tensor('bert/embeddings/token_type_embeddings'),
        ckpt.get_tensor('bert/embeddings/position_embeddings')[None, ...],
        ckpt.get_tensor('bert/embeddings/LayerNorm/gamma'),
        ckpt.get_tensor('bert/embeddings/LayerNorm/beta'),
    ]
    for i in range(12):  # 12 layers
        new_w += [
            reshape_qkv(f'bert/encoder/layer_{i}/attention/self/query/kernel'),
            reshape_qkv(f'bert/encoder/layer_{i}/attention/self/key/kernel'),
            reshape_qkv(f'bert/encoder/layer_{i}/attention/self/value/kernel'),
            reshape_o(f'bert/encoder/layer_{i}/attention/output/dense/kernel'),
            reshape_bias(f'bert/encoder/layer_{i}/attention/self/query/bias'),
            reshape_bias(f'bert/encoder/layer_{i}/attention/self/key/bias'),
            reshape_bias(f'bert/encoder/layer_{i}/attention/self/value/bias'),
            ckpt.get_tensor(f'bert/encoder/layer_{i}/attention/output/dense/bias'),
            ckpt.get_tensor(f'bert/encoder/layer_{i}/attention/output/LayerNorm/gamma'),
            ckpt.get_tensor(f'bert/encoder/layer_{i}/attention/output/LayerNorm/beta'),
            ckpt.get_tensor(f'bert/encoder/layer_{i}/intermediate/dense/kernel'),
            ckpt.get_tensor(f'bert/encoder/layer_{i}/intermediate/dense/bias'),
            ckpt.get_tensor(f'bert/encoder/layer_{i}/output/dense/kernel'),
            ckpt.get_tensor(f'bert/encoder/layer_{i}/output/dense/bias'),
            ckpt.get_tensor(f'bert/encoder/layer_{i}/output/LayerNorm/gamma'),
            ckpt.get_tensor(f'bert/encoder/layer_{i}/output/LayerNorm/beta'),
        ]
    new_w += [
        ckpt.get_tensor('bert/pooler/dense/kernel'),
        ckpt.get_tensor('bert/pooler/dense/bias'),
    ]
    for a, b in zip(fastmath.tree_leaves(self.weights), new_w):
        assert a.shape == b.shape, (
            f'Expected shape {a.shape}, got shape {b.shape}')
    self.weights = jax.tree_unflatten(jax.tree_structure(self.weights), new_w)
    move_to_device = jax.jit(lambda x: x)
    self.weights = jax.tree_map(move_to_device, self.weights)
def get_output_treedef() -> Box:
    rng = jax.random.PRNGKey(42)  # This is fine, see above
    fns = hk.transform_with_state(lambda: f()[1])
    apply_fns, _ = fns.apply(*fns.init(rng), rng)
    return Box(jax.tree_structure(apply_fns))
def params(self, new_params):
    if jax.tree_structure(new_params) != jax.tree_structure(self.params):
        raise TypeError(
            "new params must have the same structure as old params")
    self.temperature = new_params['temperature']
    self.q.params = new_params['q']
def restore_checkpoint(ckpt_dir,
                       target,
                       sharded_match_fn,
                       step=None,
                       process_id=None,
                       process_count=None):
    """Restores the last checkpoint from checkpoints in path, or a specific one.

    Sorts the checkpoint files naturally, returning the highest-valued file,
    e.g.: ckpt_1, ckpt_2, ckpt_3 --> ckpt_3.

    Important: Like Flax's restore_checkpoint(), this function doesn't handle
    replication or sharding of parameters. Once a checkpoint is loaded, call:
        target = core_utils.tree_replicate_by_name(target, not_sharded_match_fn)
        target = core_utils.tree_shard_by_name(target, sharded_match_fn)
    to replicate and shard the relevant parameters.

    Args:
        ckpt_dir: Directory of checkpoints to restore from.
        target: Serializable Flax object, usually a Flax optimizer.
        sharded_match_fn: Function that returns True if a given parameter name
            corresponds to that of a sharded parameter. If no sharded match
            function is given, we only attempt to load replicated parameters
            from ckpt_dir.
        step: Training step number. If None, restores the last one.
        process_id: Identifier for the process saving the checkpoint. If None
            (default), uses jax.process_index().
        process_count: Total number of processes in the system. If None
            (default), uses jax.process_count().

    Returns:
        Restored target updated from the checkpoint file. If no step is given
        and no checkpoints can be found, returns None.
    """
    process_id = process_id or jax.process_index()
    process_count = process_count or jax.process_count()

    # Restore parameters to replicate across all devices.
    target_to_replicate = checkpoints.restore_checkpoint(
        ckpt_dir=ckpt_dir,
        target=target,
        step=step,
        prefix=_replicated_checkpoint_pattern())
    if target_to_replicate is target:
        logging.info("No replicate checkpoint found: returning None.")
        return None

    if sharded_match_fn is None:
        # Treat all parameters as replicated; don't attempt to restore any
        # sharded parameters.
        return target_to_replicate

    target_to_shard = checkpoints.restore_checkpoint(
        ckpt_dir=ckpt_dir,
        target=target,
        step=step,
        prefix=_sharded_checkpoint_pattern(process_id, process_count))
    if target_to_shard is target:
        logging.info("No sharded checkpoint found: returning None.")
        return None

    if target is None:
        target = target_to_replicate

    treedef = jax.tree_structure(target)
    names = [name for name, _ in core_utils.tree_flatten_with_names(target)[0]]
    values_to_replicate = jax.tree_leaves(target_to_replicate)
    values_to_shard = jax.tree_leaves(target_to_shard)
    target = jax.tree_unflatten(treedef, [
        vs if sharded_match_fn(name) else vr
        for name, vr, vs in zip(names, values_to_replicate, values_to_shard)
    ])
    target = jax.tree_map(_recover_bfloat16_dtype, target)
    return target
def optimizer(self, new_optimizer):
    new_optimizer_state_structure = jax.tree_structure(
        new_optimizer.init(self._f.params))
    if new_optimizer_state_structure != jax.tree_structure(self.optimizer_state):
        raise AttributeError(
            "cannot set optimizer attr: mismatch in optimizer_state structure")
    self._optimizer = new_optimizer
def __init__(self,
             tagged_func: Callable[[Any], jnp.ndarray],
             func_args: Sequence[Any],
             l2_reg: Union[float, jnp.ndarray],
             estimation_mode: str = "fisher_gradients",
             params_index: int = 0,
             layer_tag_to_block_cls: Optional[TagMapping] = None):
    """Create a FisherEstimator object.

    Args:
        tagged_func: The function which evaluates the model, in which layer and
            loss tags have already been registered.
        func_args: Arguments to trace the function for layer and loss tags.
        l2_reg: Scalar. The L2 regularization coefficient, which represents the
            following regularization function: `coefficient/2 ||theta||^2`.
        estimation_mode: The type of curvature estimator to use. One of:
            * 'fisher_gradients' - the basic estimation approach from the
              original K-FAC paper. (Default)
            * 'fisher_curvature_prop' - method which estimates the Fisher using
              self-products of random 1/-1 vectors times "half-factors" of the
              Fisher, as described here: https://arxiv.org/abs/1206.6464
            * 'fisher_exact' - the obvious generalization of Curvature
              Propagation to compute the exact Fisher (modulo any additional
              diagonal or Kronecker approximations) by looping over one-hot
              vectors for each coordinate of the output instead of using 1/-1
              vectors. It is more expensive to compute than the other three
              options by a factor equal to the output dimension, roughly
              speaking.
            * 'fisher_empirical' - computes the 'empirical' Fisher information
              matrix (which uses the data's distribution for the targets, as
              opposed to the true Fisher which uses the model's distribution)
              and requires that each registered loss have specified targets.
            * 'ggn_curvature_prop' - Analogous to fisher_curvature_prop, but
              estimates the Generalized Gauss-Newton matrix (GGN).
            * 'ggn_exact' - Analogous to fisher_exact, but estimates the
              Generalized Gauss-Newton matrix (GGN).
        params_index: The index of the arguments accepted by `func` which
            correspond to parameters.
        layer_tag_to_block_cls: An optional dict mapping tags to specific
            classes of block approximations, used to override the default ones.
    """
    if estimation_mode not in ("fisher_gradients", "fisher_empirical",
                               "fisher_exact", "fisher_curvature_prop",
                               "ggn_exact", "ggn_curvature_prop"):
        raise ValueError(f"Unrecognised estimation_mode={estimation_mode}.")
    super().__init__()
    self.tagged_func = tagged_func
    self.l2_reg = l2_reg
    self.estimation_mode = estimation_mode
    self.params_index = params_index
    self.vjp = tracer.trace_estimator_vjp(self.tagged_func)

    # Figure out the mapping from layer tags to block classes
    self.layer_tag_to_block_cls = curvature_blocks.copy_default_tag_to_block()
    if layer_tag_to_block_cls is None:
        layer_tag_to_block_cls = dict()
    layer_tag_to_block_cls = dict(**layer_tag_to_block_cls)
    self.layer_tag_to_block_cls.update(layer_tag_to_block_cls)

    # Create the blocks
    self._in_tree = jax.tree_structure(func_args)
    self._jaxpr = jax.make_jaxpr(self.tagged_func)(*func_args).jaxpr
    self._layer_tags, self._loss_tags = tracer.extract_tags(self._jaxpr)
    self.blocks = collections.OrderedDict()
    counters = dict()
    for eqn in self._layer_tags:
        cls = self.layer_tag_to_block_cls[eqn.primitive.name]
        c = counters.get(cls.__name__, 0)
        self.blocks[cls.__name__ + "_" + str(c)] = cls(eqn)
        counters[cls.__name__] = c + 1
def function_state(self, new_function_state):
    if jax.tree_structure(new_function_state) != jax.tree_structure(
            self._function_state):
        raise TypeError(
            "new function_state must have the same structure as old function_state")
    self._function_state = new_function_state
def do_trees_have_same_structure(a, b):
    """Returns True if two jax trees have the same structure."""
    return jax.tree_structure(a) == jax.tree_structure(b)
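# Hypothetical usage sketch (inputs invented for illustration, not part of the
# original source): two dicts with the same keys share a treedef even if their
# leaf shapes differ, while a dict and a list do not.
import jax.numpy as jnp

a = {"w": jnp.zeros((2, 3)), "b": jnp.zeros((3,))}
b = {"w": jnp.ones((4, 5)), "b": jnp.ones((1,))}
assert do_trees_have_same_structure(a, b)            # same structure
assert not do_trees_have_same_structure(a, [1, 2])   # different structure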
def check_structure_shapes_and_dtype(obj1: T, obj2: T) -> None:
    """Verifies that the two objects have the same pytree structure, shapes and dtypes."""
    assert jax.tree_structure(obj1) == jax.tree_structure(obj2)
    for v1, v2 in zip(jax.tree_flatten(obj1)[0], jax.tree_flatten(obj2)[0]):
        assert v1.shape == v2.shape
        assert v1.dtype == v2.dtype
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast decoding routines for inference from a trained model."""

from typing import Callable, Mapping, Optional, Tuple

import flax
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
import numpy as np

PyTreeDef = type(jax.tree_structure(None))

SamplingLoopState = Tuple[int, jnp.ndarray, Mapping[str, jnp.ndarray],
                          jnp.ndarray, jnp.ndarray, jnp.ndarray]

# Constants
# We assume the default End-of-Sentence token id is 1 (T5 convention).
EOS_ID = 1
# "Effective negative infinity" constant for masking in beam search.
NEG_INF = np.array(-1.0e7)


#------------------------------------------------------------------------------
# Temperature Sampling
#------------------------------------------------------------------------------
def temperature_sample(prompt_inputs,
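# Hypothetical check (not part of the original module): every treedef returned
# by jax.tree_structure is an instance of the PyTreeDef alias defined above,
# which is why it can be used in type annotations.
assert isinstance(jax.tree_structure({"a": 1, "b": (2, 3)}), PyTreeDef)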
def inner_product(obj1: T, obj2: T) -> jnp.ndarray:
    if jax.tree_structure(obj1) != jax.tree_structure(obj2):
        raise ValueError("The two structures are not identical.")
    elements_product = jax.tree_multimap(lambda x, y: jnp.sum(x * y), obj1, obj2)
    return sum(jax.tree_flatten(elements_product)[0])
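# Hypothetical usage sketch (inputs invented for illustration): the inner
# product of a pytree with itself is its squared Frobenius norm over all leaves.
import jax.numpy as jnp

params = {"w": jnp.array([[1.0, 2.0], [3.0, 4.0]]), "b": jnp.array([1.0, -1.0])}
sq_norm = inner_product(params, params)  # 1 + 4 + 9 + 16 + 1 + 1 = 32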
def make(
    factors: Iterable[FactorBase],
    use_onp: bool = True,
) -> "StackedFactorGraph":
    """Create a factor graph from a set of factors."""

    # Start by grouping our factors and grabbing a list of (ordered!) variables
    factors_from_group: DefaultDict[GroupKey, List[FactorBase]] = defaultdict(list)
    variables_ordered_set: Dict[VariableBase, None] = {}
    for factor in factors:
        # Each factor is ultimately just a pytree node; in order for a set of
        # factors to be batchable, they must share the same:
        group_key: GroupKey = (
            # (1) Treedef. Note that variables can be different as long as
            # their types are the same.
            jax.tree_structure(factor.anonymize_variables()),
            # (2) Leaf shapes: contained array shapes must match
            tuple(
                leaf.shape if hasattr(leaf, "shape") else ()
                for leaf in jax.tree_leaves(factor)
            ),
        )

        # Record factor and variables
        factors_from_group[group_key].append(factor)
        for v in factor.variables:
            variables_ordered_set[v] = None
    variables = list(variables_ordered_set.keys())

    # Fields we want to populate
    stacked_factors: List[FactorStack] = []
    jacobian_coords: List[sparse.SparseCooCoordinates] = []

    # Create storage layout: this describes which parts of our storage object
    # are allocated to each variable
    storage_layout = StorageLayout.make(variables, local=False)
    local_storage_layout = StorageLayout.make(variables, local=True)

    # Prepare each factor group
    residual_offset = 0
    for group_key, group in factors_from_group.items():
        # Make factor stack
        stacked_factors.append(
            FactorStack.make(
                group,
                storage_layout,
                use_onp=use_onp,
            )
        )

        # Compute Jacobian coordinates
        #
        # These should be N pairs of (row, col) indices, where rows correspond
        # to residual indices and columns correspond to local parameter indices
        jacobian_coords.extend(
            FactorStack.compute_jacobian_coords(
                factors=group,
                local_storage_layout=local_storage_layout,
                row_offset=residual_offset,
            )
        )
        residual_offset += stacked_factors[-1].get_residual_dim()

    jacobian_coords_concat: sparse.SparseCooCoordinates = jax.tree_map(
        lambda *arrays: onp.concatenate(arrays, axis=0), *jacobian_coords
    )

    return StackedFactorGraph(
        factor_stacks=stacked_factors,
        jacobian_coords=jacobian_coords_concat,
        storage_layout=storage_layout,
        local_storage_layout=local_storage_layout,
        residual_dim=residual_offset,
    )
def func(params, state, rng, S, is_training):
    rngs = hk.PRNGSequence(rng)
    new_state = dict(state)

    # s' ~ p(s'|s,.)
    # note: S_next is replicated, one for each (discrete) action
    if is_stochastic(self.p):
        dist_params_rep, new_state['p'] = self.p.function_type2(
            params['p'], state['p'], next(rngs), S, is_training)
        dist_params_rep = jax.tree_map(self._reshape_to_replicas, dist_params_rep)
        S_next_rep = self.p.proba_dist.sample(dist_params_rep, next(rngs))
    else:
        S_next_rep, new_state['p'] = self.p.function_type2(
            params['p'], state['p'], next(rngs), S, is_training)
        S_next_rep = jax.tree_map(self._reshape_to_replicas, S_next_rep)

    # r ~ p(r|s,a)
    # note: R is replicated, one for each (discrete) action
    if is_stochastic(self.r):
        dist_params_rep, new_state['r'] = self.r.function_type2(
            params['r'], state['r'], next(rngs), S, is_training)
        dist_params_rep = jax.tree_map(self._reshape_to_replicas, dist_params_rep)
        R_rep = self.r.proba_dist.sample(dist_params_rep, next(rngs))
        R_rep = self.r.proba_dist.postprocess_variate(
            next(rngs), R_rep, batch_mode=True)
    else:
        R_rep, new_state['r'] = self.r.function_type2(
            params['r'], state['r'], next(rngs), S, is_training)
        R_rep = jax.tree_map(self._reshape_to_replicas, R_rep)

    # v(s')
    # note: since the input S_next is replicated, so is the output V
    if is_stochastic(self.v):
        dist_params_rep, new_state['v'] = self.v.function(
            params['v'], state['v'], next(rngs), S_next_rep, is_training)
        V_rep = self.v.proba_dist.sample(dist_params_rep, next(rngs))
        V_rep = self.v.proba_dist.postprocess_variate(
            next(rngs), V_rep, batch_mode=True)
    else:
        V_rep, new_state['v'] = self.v.function(
            params['v'], state['v'], next(rngs), S_next_rep, is_training)

    # q = r + γ v(s')
    f, f_inv = self.value_transform
    Q_rep = f(R_rep + params['gamma'] * f_inv(V_rep))

    # reshape from (batch x num_actions, *) to (batch, num_actions, *)
    Q_s = self._reshape_from_replicas(Q_rep)
    assert Q_s.ndim == 2, f"bad shape: {Q_s.shape}"
    assert Q_s.shape[1] == self.action_space.n, f"bad shape: {Q_s.shape}"

    new_state = hk.data_structures.to_immutable_dict(new_state)
    assert jax.tree_structure(new_state) == jax.tree_structure(state)

    return Q_s, new_state
def random_split_like_tree(rng_key, target=None, treedef=None):
    if treedef is None:
        treedef = jax.tree_structure(target)
    keys = jax.random.split(rng_key, treedef.num_leaves)
    return jax.tree_unflatten(treedef, keys)
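# Hypothetical usage sketch (names invented for illustration): draw one
# independent PRNG key per leaf, e.g. to add noise to every parameter array.
import jax
import jax.numpy as jnp

params = {"w": jnp.zeros((2, 3)), "b": jnp.zeros((3,))}
key_tree = random_split_like_tree(jax.random.PRNGKey(0), target=params)
noisy = jax.tree_multimap(
    lambda p, k: p + jax.random.normal(k, p.shape), params, key_tree)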
def __call__(self, inputs, state):
    """Run one step of the wrapped core, handling state reset.

    Args:
        inputs: Tuple with two elements, ``inputs, should_reset``, where
            ``should_reset`` is the signal used to reset the wrapped core's
            state. ``should_reset`` can be either tensor or nest. If nest,
            ``should_reset`` must match the state structure, and its
            components' shapes must be prefixes of the corresponding entries'
            tensor shapes in the state nest. If tensor, supported shapes are
            all common shape prefixes of the state component tensors, e.g.
            ``[batch_size]``.
        state: Previous wrapped core state.

    Returns:
        Tuple of the wrapped core's ``output, next_state``.
    """
    inputs, should_reset = inputs
    if jax.treedef_is_leaf(jax.tree_structure(should_reset)):
        # Equivalent to not tree.is_nested, but with support for Jax extensible
        # pytrees.
        should_reset = jax.tree_map(lambda _: should_reset, state)

    # We now need to manually pad 'on the right' to ensure broadcasting
    # operates correctly.
    # Automatic broadcasting would in fact implicitly pad 'on the left',
    # resulting in the signal triggering resets for parts of the state
    # across batch entries. For example:
    #
    #   import jax
    #   import jax.numpy as jnp
    #
    #   shape = (2, 2, 2)
    #   x = jnp.zeros(shape)
    #   y = jnp.ones(shape)
    #   should_reset = jnp.array([False, True])
    #   v = jnp.where(should_reset, x, y)
    #   for batch_entry in range(shape[0]):
    #       print("batch_entry {}:\n".format(batch_entry), v[batch_entry])
    #
    #   >> batch_entry 0:
    #   >> [[1. 0.]
    #   >>  [1. 0.]]
    #   >> batch_entry 1:
    #   >> [[1. 0.]
    #   >>  [1. 0.]]
    #
    # Note how manually padding the should_reset tensor yields the desired
    # behavior.
    #
    #   import jax
    #   import jax.numpy as jnp
    #
    #   shape = (2, 2, 2)
    #   x = jnp.zeros(shape)
    #   y = jnp.ones(shape)
    #   should_reset = jnp.array([False, True])
    #   dims_to_add = x.ndim - should_reset.ndim
    #   should_reset = should_reset.reshape(should_reset.shape + (1,) * dims_to_add)
    #   v = jnp.where(should_reset, x, y)
    #   for batch_entry in range(shape[0]):
    #       print("batch_entry {}:\n".format(batch_entry), v[batch_entry])
    #
    #   >> batch_entry 0:
    #   >> [[1. 1.]
    #   >>  [1. 1.]]
    #   >> batch_entry 1:
    #   >> [[0. 0.]
    #   >>  [0. 0.]]
    should_reset = jax.tree_multimap(_validate_and_conform, should_reset, state)
    if self._is_batched(state):
        batch_size = jax.tree_leaves(inputs)[0].shape[0]
    else:
        batch_size = None
    initial_state = jax.tree_multimap(lambda s, i: i.astype(s.dtype),
                                      state, self.initial_state(batch_size))
    state = jax.tree_multimap(jnp.where, should_reset, initial_state, state)
    return self.core(inputs, state)
def optimizer_state(self, new_optimizer_state):
    if jax.tree_structure(new_optimizer_state) != jax.tree_structure(
            self.optimizer_state):
        raise AttributeError(
            "cannot set optimizer_state attr: mismatch in tree structure")
    self._optimizer_state = new_optimizer_state
def params(self, new_params):
    if jax.tree_structure(new_params) != jax.tree_structure(self._params):
        raise TypeError(
            "new params must have the same structure as old params")
    self._params = new_params
def catch_treedef(scopes, *args):
    treedef = jax.tree_structure(scopes)
    return wrapper(scopes, treedef, *args)
def make(
    factors: Iterable[FactorBase],
    use_onp: bool = True,
) -> "StackedFactorGraph":
    """Create a factor graph from a set of factors."""

    # Start by grouping our factors and grabbing a list of (ordered!) variables
    factors_from_group: DefaultDict[GroupKey, List[FactorBase]] = defaultdict(list)
    variables_ordered_set: Dict[VariableBase, None] = {}
    for factor in factors:
        # Each factor is ultimately just a PyTree node; in order for a set of
        # factors to be batchable, they must share the same:
        group_key: GroupKey = (
            # (1) Tree structure: this encompasses the factor class, variable types
            jax.tree_structure(factor),
            # (2) Leaf shapes: array shapes must match in order to be stacked
            tuple(leaf.shape for leaf in jax.tree_leaves(factor)),
        )

        # Record factor and variables
        factors_from_group[group_key].append(factor)
        for v in factor.variables:
            variables_ordered_set[v] = None
    variables = list(variables_ordered_set.keys())

    # Make sure factors are unique
    for factors in factors_from_group.values():
        assert len(factors) == len(set(factors))

    # Fields we want to populate
    stacked_factors: List[FactorStack] = []
    jacobian_coords: List[sparse.SparseCooCoordinates] = []

    # Create storage metadata: this determines which parts of our storage
    # object are allocated to each variable type
    storage_metadata = StorageMetadata.make(variables, local=False)
    local_storage_metadata = StorageMetadata.make(variables, local=True)

    # Prepare each factor group
    residual_offset = 0
    for group_key, group in factors_from_group.items():
        # Make factor stack
        stacked_factors.append(
            FactorStack.make(
                group,
                storage_metadata,
                use_onp=use_onp,
            ))

        # Compute Jacobian coordinates
        #
        # These should be N pairs of (row, col) indices, where rows correspond
        # to residual indices and columns correspond to local parameter indices
        jacobian_coords.extend(
            FactorStack.compute_jacobian_coords(
                factors=group,
                local_storage_metadata=local_storage_metadata,
                row_offset=residual_offset,
            ))
        residual_offset += stacked_factors[-1].get_residual_dim()

    jacobian_coords_concat: sparse.SparseCooCoordinates = jax.tree_multimap(
        lambda *arrays: onp.concatenate(arrays, axis=0), *jacobian_coords)

    return StackedFactorGraph(
        factor_stacks=stacked_factors,
        jacobian_coords=jacobian_coords_concat,
        local_storage_metadata=local_storage_metadata,
        residual_dim=residual_offset,
    )