Example #1
0
 def dist_params_structure(self):
     r"""Return the pytree structure (treedef) of the distribution parameters.

     The structure is taken from ``self.default_priors``.

     Returns:
         The treedef of ``self.default_priors``.
     """
     # ``jax.tree_util.tree_structure`` is the stable spelling; the top-level
     # ``jax.tree_structure`` alias was deprecated and removed in newer JAX.
     return jax.tree_util.tree_structure(self.default_priors)
Example #2
0
    def init_weights_and_state(self, input_signature):
        """Initialize weights, then optionally overwrite them from a checkpoint.

        After the parent initialization, if ``self.init_checkpoint`` is set,
        tensors are read from that TensorFlow checkpoint in a fixed order
        (embeddings, ``num_layers`` encoder layers, pooler). That order is
        assumed to match the flattened order of ``self.weights``. Attention
        tensors are reshaped first. Every tensor is shape-checked against the
        freshly initialized weight it replaces.

        Args:
          input_signature: Forwarded unchanged to the parent implementation.

        Raises:
          ValueError: If the checkpoint yields a different number of tensors
            than the layer has weights, or if any tensor's shape differs.
        """
        super().init_weights_and_state(input_signature)
        if self.init_checkpoint is None:
            return

        print('Loading pre-trained weights from', self.init_checkpoint)
        ckpt = tf.train.load_checkpoint(self.init_checkpoint)

        num_layers = 12  # Number of encoder layers read from the checkpoint.
        # Trailing size the attention projections are split into (presumably
        # the per-head dimension).
        head_dim = 64

        def reshape_qkv(name):
            # Assumes a 2-D (in, out) kernel: -> (out // head_dim, in, head_dim).
            x = ckpt.get_tensor(name)
            return x.reshape((x.shape[0], -1, head_dim)).swapaxes(0, 1)

        def reshape_o(name):
            # Assumes a 2-D (in, out) kernel: -> (in // head_dim, head_dim, out).
            x = ckpt.get_tensor(name)
            return x.reshape((-1, head_dim, x.shape[-1]))

        def reshape_bias(name):
            # (n,) bias -> (n // head_dim, head_dim).
            x = ckpt.get_tensor(name)
            return x.reshape((-1, head_dim))

        new_w = [
            ckpt.get_tensor('bert/embeddings/word_embeddings'),
            ckpt.get_tensor('bert/embeddings/token_type_embeddings'),
            # The position table gets an extra leading axis of size 1.
            ckpt.get_tensor('bert/embeddings/position_embeddings')[None, ...],
            ckpt.get_tensor('bert/embeddings/LayerNorm/gamma'),
            ckpt.get_tensor('bert/embeddings/LayerNorm/beta'),
        ]

        for i in range(num_layers):
            prefix = f'bert/encoder/layer_{i}'
            new_w += [
                reshape_qkv(f'{prefix}/attention/self/query/kernel'),
                reshape_qkv(f'{prefix}/attention/self/key/kernel'),
                reshape_qkv(f'{prefix}/attention/self/value/kernel'),
                reshape_o(f'{prefix}/attention/output/dense/kernel'),
                reshape_bias(f'{prefix}/attention/self/query/bias'),
                reshape_bias(f'{prefix}/attention/self/key/bias'),
                reshape_bias(f'{prefix}/attention/self/value/bias'),
                ckpt.get_tensor(f'{prefix}/attention/output/dense/bias'),
                ckpt.get_tensor(f'{prefix}/attention/output/LayerNorm/gamma'),
                ckpt.get_tensor(f'{prefix}/attention/output/LayerNorm/beta'),
                ckpt.get_tensor(f'{prefix}/intermediate/dense/kernel'),
                ckpt.get_tensor(f'{prefix}/intermediate/dense/bias'),
                ckpt.get_tensor(f'{prefix}/output/dense/kernel'),
                ckpt.get_tensor(f'{prefix}/output/dense/bias'),
                ckpt.get_tensor(f'{prefix}/output/LayerNorm/gamma'),
                ckpt.get_tensor(f'{prefix}/output/LayerNorm/beta'),
            ]

        new_w += [
            ckpt.get_tensor('bert/pooler/dense/kernel'),
            ckpt.get_tensor('bert/pooler/dense/bias'),
        ]

        # zip() would silently truncate on a count mismatch. Check it
        # explicitly, and use real exceptions rather than asserts, which are
        # stripped under `python -O`.
        old_leaves = fastmath.tree_leaves(self.weights)
        if len(old_leaves) != len(new_w):
            raise ValueError(
                f'Expected {len(old_leaves)} weight tensors, checkpoint '
                f'provided {len(new_w)}')
        for a, b in zip(old_leaves, new_w):
            if a.shape != b.shape:
                raise ValueError(
                    f'Expected shape {a.shape}, got shape {b.shape}')

        self.weights = jax.tree_util.tree_unflatten(
            jax.tree_util.tree_structure(self.weights), new_w)
        # An identity function under jit returns its inputs as device arrays.
        move_to_device = jax.jit(lambda x: x)
        self.weights = jax.tree_util.tree_map(move_to_device, self.weights)
Example #3
0
 def get_output_treedef() -> Box:
     """Return the pytree structure of ``f()[1]``, wrapped in a ``Box``.

     ``f`` is transformed with ``hk.transform_with_state`` and then
     initialized and applied with a fixed key. Only the treedef of the
     result is kept. ``f``, ``hk`` and ``Box`` come from the enclosing
     scope.
     """
     rng = jax.random.PRNGKey(42)  # This is fine, see above
     fns = hk.transform_with_state(lambda: f()[1])
     apply_fns, _ = fns.apply(*fns.init(rng), rng)
     # Stable ``jax.tree_util`` spelling; the top-level alias was removed.
     return Box(jax.tree_util.tree_structure(apply_fns))
Example #4
0
 def params(self, new_params):
     """Replace the parameters, requiring an unchanged pytree structure.

     Args:
         new_params: Pytree indexed by ``'temperature'`` and ``'q'``. Its
             structure must equal that of ``self.params``.

     Raises:
         TypeError: If the structure of ``new_params`` differs from that of
             the current parameters.
     """
     new_structure = jax.tree_util.tree_structure(new_params)
     if new_structure != jax.tree_util.tree_structure(self.params):
         raise TypeError(
             "new params must have the same structure as old params")
     self.temperature = new_params['temperature']
     self.q.params = new_params['q']
def restore_checkpoint(ckpt_dir,
                       target,
                       sharded_match_fn,
                       step=None,
                       process_id=None,
                       process_count=None):
  """Restores the last checkpoint from checkpoints in path, or a specific one.

  Sorts the checkpoint files naturally, returning the highest-valued file, e.g.:
    ckpt_1, ckpt_2, ckpt_3 --> ckpt_3.

  Important: Like Flax's restore_checkpoint(), this function doesn't handle
  replication or sharding of parameters. Once a checkpoint is loaded, call:
    target = core_utils.tree_replicate_by_name(target, not_sharded_match_fn)
    target = core_utils.tree_shard_by_name(target, sharded_match_fn)
  to replicate and shard the relevant parameters.

  Args:
    ckpt_dir: Directory of checkpoints to restore from.
    target: Serializable Flax object, usually a Flax optimizer. If None, the
      restored replicated state is used as the structure template when
      merging replicated and sharded values.
    sharded_match_fn: Function that returns true if a given parameter name
      corresponds to that of a sharded parameter. If None, only replicated
      parameters are restored from ckpt_dir.
    step: Training step number. If None, restores the last one.
    process_id: Identifier for process saving the checkpoint. If None (default),
      uses jax.process_index(). An explicit 0 is honored.
    process_count: Total number of processes in the system. If None (default),
      uses jax.process_count().

  Returns:
    Restored target updated from checkpoint file. Returns None if no
    replicated checkpoint is found, or, when sharded_match_fn is given, if no
    sharded checkpoint is found.

  Raises:
    ValueError: If the replicated and sharded checkpoints do not have the same
      number of leaves as the target.
  """
  # Compare against None: `process_id or ...` would discard a legitimate
  # explicit process_id of 0.
  if process_id is None:
    process_id = jax.process_index()
  if process_count is None:
    process_count = jax.process_count()

  # Restore parameters to replicate across all devices.
  target_to_replicate = checkpoints.restore_checkpoint(
      ckpt_dir=ckpt_dir,
      target=target,
      step=step,
      prefix=_replicated_checkpoint_pattern())
  # The restore call hands back `target` itself when nothing was found.
  if target_to_replicate is target:
    logging.info("No replicate checkpoint found: returning None.")
    return None

  if sharded_match_fn is None:
    # Treat all parameters as replicated; don't attempt to restore any sharded
    # parameters.
    return target_to_replicate

  target_to_shard = checkpoints.restore_checkpoint(
      ckpt_dir=ckpt_dir,
      target=target,
      step=step,
      prefix=_sharded_checkpoint_pattern(process_id, process_count))
  if target_to_shard is target:
    logging.info("No sharded checkpoint found: returning None.")
    return None

  if target is None:
    target = target_to_replicate
  treedef = jax.tree_util.tree_structure(target)
  names = [name for name, _ in core_utils.tree_flatten_with_names(target)[0]]
  values_to_replicate = jax.tree_util.tree_leaves(target_to_replicate)
  values_to_shard = jax.tree_util.tree_leaves(target_to_shard)
  # zip() would silently truncate and mis-assign leaves on a mismatch.
  if not len(names) == len(values_to_replicate) == len(values_to_shard):
    raise ValueError(
        f"Leaf count mismatch: target has {len(names)}, replicated checkpoint "
        f"has {len(values_to_replicate)}, sharded checkpoint has "
        f"{len(values_to_shard)}.")
  # For each leaf, take the sharded value if its name matches, else the
  # replicated one.
  target = jax.tree_util.tree_unflatten(treedef, [
      vs if sharded_match_fn(name) else vr
      for name, vr, vs in zip(names, values_to_replicate, values_to_shard)
  ])

  target = jax.tree_util.tree_map(_recover_bfloat16_dtype, target)

  return target
Example #6
0
 def optimizer(self, new_optimizer):
     """Swap in a new optimizer while keeping the existing optimizer state.

     The new optimizer is accepted only if the state it would initialize for
     the current parameters (``self._f.params``) has the same pytree
     structure as ``self.optimizer_state``. The existing state itself is not
     modified.

     Raises:
         AttributeError: If the two state structures differ.
     """
     new_state_structure = jax.tree_util.tree_structure(
         new_optimizer.init(self._f.params))
     current_structure = jax.tree_util.tree_structure(self.optimizer_state)
     if new_state_structure != current_structure:
         raise AttributeError(
             "cannot set optimizer attr: mismatch in optimizer_state structure")
     self._optimizer = new_optimizer
    def __init__(self,
                 tagged_func: Callable[[Any], jnp.ndarray],
                 func_args: Sequence[Any],
                 l2_reg: Union[float, jnp.ndarray],
                 estimation_mode: str = "fisher_gradients",
                 params_index: int = 0,
                 layer_tag_to_block_cls: Optional[TagMapping] = None):
        """Create a FisherEstimator object.

        Args:
          tagged_func: The function which evaluates the model, in which layer
            and loss tags have already been registered.
          func_args: Arguments to trace the function for layer and loss tags.
          l2_reg: Scalar. The L2 regularization coefficient, which represents
            the following regularization function: `coefficient/2 ||theta||^2`.
          estimation_mode: The type of curvature estimator to use. One of:

            * 'fisher_gradients' - the basic estimation approach from the
              original K-FAC paper. (Default)
            * 'fisher_curvature_prop' - method which estimates the Fisher
              using self-products of random 1/-1 vectors times "half-factors"
              of the Fisher, as described here:
              https://arxiv.org/abs/1206.6464
            * 'fisher_exact' - is the obvious generalization of Curvature
              Propagation to compute the exact Fisher (modulo any additional
              diagonal or Kronecker approximations) by looping over one-hot
              vectors for each coordinate of the output instead of using 1/-1
              vectors. It is more expensive to compute than the other
              options by a factor equal to the output dimension, roughly
              speaking.
            * 'fisher_empirical' - computes the 'empirical' Fisher
              information matrix (which uses the data's distribution for the
              targets, as opposed to the true Fisher which uses the model's
              distribution) and requires that each registered loss have
              specified targets.
            * 'ggn_curvature_prop' - Analogous to fisher_curvature_prop, but
              estimates the Generalized Gauss-Newton matrix (GGN).
            * 'ggn_exact' - Analogous to fisher_exact, but estimates the
              Generalized Gauss-Newton matrix (GGN).
          params_index: The index of the arguments accepted by `func` which
            correspond to parameters.
          layer_tag_to_block_cls: An optional dict mapping tags to specific
            classes of block approximations, which override the default ones.

        Raises:
          ValueError: If `estimation_mode` is not one of the values above.
        """
        valid_modes = ("fisher_gradients", "fisher_empirical",
                       "fisher_exact", "fisher_curvature_prop",
                       "ggn_exact", "ggn_curvature_prop")
        if estimation_mode not in valid_modes:
            raise ValueError(
                f"Unrecognised estimation_mode={estimation_mode}.")
        super().__init__()
        self.tagged_func = tagged_func
        self.l2_reg = l2_reg
        self.estimation_mode = estimation_mode
        self.params_index = params_index
        self.vjp = tracer.trace_estimator_vjp(self.tagged_func)

        # Start from a copy of the default tag -> block-class mapping and let
        # user-supplied entries override it. `dict.update` takes the mapping
        # directly, without the string-key restriction of `dict(**mapping)`.
        self.layer_tag_to_block_cls = (
            curvature_blocks.copy_default_tag_to_block())
        if layer_tag_to_block_cls is not None:
            self.layer_tag_to_block_cls.update(layer_tag_to_block_cls)

        # Trace the function once to find its registered layer and loss tags.
        self._in_tree = jax.tree_util.tree_structure(func_args)
        self._jaxpr = jax.make_jaxpr(self.tagged_func)(*func_args).jaxpr
        self._layer_tags, self._loss_tags = tracer.extract_tags(self._jaxpr)

        # Create one block per layer tag, keyed "<BlockClass>_<k>", where k
        # counts the blocks of that class created so far.
        self.blocks = collections.OrderedDict()
        counters = dict()
        for eqn in self._layer_tags:
            cls = self.layer_tag_to_block_cls[eqn.primitive.name]
            c = counters.get(cls.__name__, 0)
            self.blocks[cls.__name__ + "_" + str(c)] = cls(eqn)
            counters[cls.__name__] = c + 1
Example #8
0
 def function_state(self, new_function_state):
     """Replace the function state, requiring an unchanged pytree structure.

     Raises:
         TypeError: If ``new_function_state`` does not have the same pytree
             structure as the current ``self._function_state``.
     """
     new_structure = jax.tree_util.tree_structure(new_function_state)
     if new_structure != jax.tree_util.tree_structure(self._function_state):
         raise TypeError("new function_state must have the same structure as old function_state")
     self._function_state = new_function_state
Example #9
0
File: util.py Project: byzhang/d3p
def do_trees_have_same_structure(a, b):
    """Return True if two pytrees have the same structure.

    Only the treedefs are compared (container types, dict keys and leaf
    count); leaf values are ignored.
    """
    return jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure(b)
Example #10
0
def check_structure_shapes_and_dtype(obj1: T, obj2: T) -> None:
    """Assert that two pytrees match in structure, leaf shapes and leaf dtypes.

    Leaves are compared pairwise in flattened order. Every leaf must expose
    ``.shape`` and ``.dtype``.

    Raises:
        AssertionError: On the first mismatch found.
    """
    structure1 = jax.tree_util.tree_structure(obj1)
    structure2 = jax.tree_util.tree_structure(obj2)
    assert structure1 == structure2, (
        f"structure mismatch: {structure1} vs {structure2}")
    for v1, v2 in zip(jax.tree_util.tree_leaves(obj1),
                      jax.tree_util.tree_leaves(obj2)):
        assert v1.shape == v2.shape, (
            f"shape mismatch: {v1.shape} vs {v2.shape}")
        assert v1.dtype == v2.dtype, (
            f"dtype mismatch: {v1.dtype} vs {v2.dtype}")
Example #11
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fast decoding routines for inference from a trained model."""

from typing import Callable, Mapping, Optional, Tuple
import flax
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
import numpy as np

# Type of JAX treedef objects, taken from an actual treedef instance. Uses the
# stable `jax.tree_util` spelling: the top-level `jax.tree_structure` alias was
# removed in newer JAX releases and would fail at import time.
PyTreeDef = type(jax.tree_util.tree_structure(None))
# Loop-carry tuple type, presumably for the temperature-sampling loop.
SamplingLoopState = Tuple[int, jnp.ndarray, Mapping[str, jnp.ndarray],
                          jnp.ndarray, jnp.ndarray, jnp.ndarray]

# Constants
# We assume the default End-of-Sentence token id is 1 (T5 convention).
EOS_ID = 1
# "Effective negative infinity" constant for masking in beam search.
NEG_INF = np.array(-1.0e7)

#------------------------------------------------------------------------------
# Temperature Sampling
#------------------------------------------------------------------------------


def temperature_sample(prompt_inputs,
Example #12
0
def inner_product(obj1: T, obj2: T) -> jnp.ndarray:
    """Return the inner product of two pytrees with identical structure.

    Computes ``sum_i sum(leaf1_i * leaf2_i)`` over corresponding leaves.

    Raises:
        ValueError: If the two pytrees do not share the same structure.
    """
    structure1 = jax.tree_util.tree_structure(obj1)
    structure2 = jax.tree_util.tree_structure(obj2)
    if structure1 != structure2:
        raise ValueError("The two structures are not identical.")
    # `jax.tree_multimap` was removed from JAX; `tree_map` takes extra trees.
    elements_product = jax.tree_util.tree_map(
        lambda x, y: jnp.sum(x * y), obj1, obj2)
    return sum(jax.tree_util.tree_leaves(elements_product))
Example #13
0
    def make(
        factors: Iterable[FactorBase],
        use_onp: bool = True,
    ) -> "StackedFactorGraph":
        """Create a factor graph from a set of factors.

        Factors are grouped by treedef (after anonymizing their variables)
        and by leaf shapes. Each group becomes one ``FactorStack``.

        Args:
            factors: Factors to include in the graph.
            use_onp: Forwarded to ``FactorStack.make``.

        Returns:
            A ``StackedFactorGraph`` holding one factor stack per group, the
            concatenated Jacobian coordinates of all groups, the global and
            local storage layouts, and the total residual dimension.
        """
        # Group factors and collect variables in first-seen order; a dict with
        # ``None`` values serves as an insertion-ordered set.
        factors_from_group: DefaultDict[GroupKey, List[FactorBase]] = defaultdict(list)
        variables_ordered_set: Dict[VariableBase, None] = {}
        for factor in factors:
            # Each factor is ultimately just a pytree node; in order for a set
            # of factors to be batchable, they must share the same:
            group_key: GroupKey = (
                # (1) Treedef. Note that variables can be different as long as
                # their types are the same.
                jax.tree_util.tree_structure(factor.anonymize_variables()),
                # (2) Leaf shapes: contained array shapes must match. Leaves
                # without a ``shape`` attribute are keyed as ().
                tuple(
                    leaf.shape if hasattr(leaf, "shape") else ()
                    for leaf in jax.tree_util.tree_leaves(factor)
                ),
            )

            # Record factor and variables
            factors_from_group[group_key].append(factor)
            for v in factor.variables:
                variables_ordered_set[v] = None
        variables = list(variables_ordered_set)

        # Fields we want to populate
        stacked_factors: List[FactorStack] = []
        jacobian_coords: List[sparse.SparseCooCoordinates] = []

        # Create storage layout: this describes which parts of our storage
        # object are allocated to each variable
        storage_layout = StorageLayout.make(variables, local=False)
        local_storage_layout = StorageLayout.make(variables, local=True)

        # Prepare each factor group; each group's residual rows are placed
        # after those of all previous groups.
        residual_offset = 0
        for group in factors_from_group.values():
            # Make factor stack
            stacked_factors.append(
                FactorStack.make(
                    group,
                    storage_layout,
                    use_onp=use_onp,
                )
            )

            # Compute Jacobian coordinates
            #
            # These should be N pairs of (row, col) indices, where rows
            # correspond to residual indices and columns correspond to local
            # parameter indices
            jacobian_coords.extend(
                FactorStack.compute_jacobian_coords(
                    factors=group,
                    local_storage_layout=local_storage_layout,
                    row_offset=residual_offset,
                )
            )
            residual_offset += stacked_factors[-1].get_residual_dim()

        # Concatenate the per-group coordinate pytrees leaf by leaf.
        # NOTE(review): with no factors, `jacobian_coords` is empty and
        # `tree_map` gets no tree to map over (TypeError); presumably callers
        # always pass at least one factor — confirm.
        jacobian_coords_concat: sparse.SparseCooCoordinates = (
            jax.tree_util.tree_map(
                lambda *arrays: onp.concatenate(arrays, axis=0),
                *jacobian_coords,
            )
        )

        return StackedFactorGraph(
            factor_stacks=stacked_factors,
            jacobian_coords=jacobian_coords_concat,
            storage_layout=storage_layout,
            local_storage_layout=local_storage_layout,
            residual_dim=residual_offset,
        )
Example #14
0
            def func(params, state, rng, S, is_training):
                """Compute Q-values for every discrete action via lookahead.

                Implements ``q = f(r + gamma * f_inv(v(s')))``, where
                ``f, f_inv = self.value_transform``. ``s'`` comes from the
                dynamics model ``self.p``: it is sampled if ``self.p`` is
                stochastic, otherwise taken as its output. Inputs are
                replicated once per discrete action.

                Args:
                    params: Mapping with entries ``'p'``, ``'r'``, ``'v'`` and
                        ``'gamma'``.
                    state: Mapping with entries ``'p'``, ``'r'`` and ``'v'``.
                    rng: PRNG key, split into a fresh key for every
                        stochastic call.
                    S: Batch of states.
                    is_training: Forwarded to the wrapped function
                        approximators.

                Returns:
                    ``(Q_s, new_state)``. ``Q_s`` is 2-D with
                    ``self.action_space.n`` columns (asserted), and
                    ``new_state`` is an immutable dict with the same
                    structure as ``state``.
                """
                rngs = hk.PRNGSequence(rng)
                new_state = dict(state)

                # s' ~ p(s'|s,.)  # note: S_next is replicated, one for each (discrete) action
                if is_stochastic(self.p):
                    dist_params_rep, new_state['p'] = self.p.function_type2(
                        params['p'], state['p'], next(rngs), S, is_training)
                    dist_params_rep = jax.tree_util.tree_map(
                        self._reshape_to_replicas, dist_params_rep)
                    S_next_rep = self.p.proba_dist.sample(
                        dist_params_rep, next(rngs))
                else:
                    S_next_rep, new_state['p'] = self.p.function_type2(
                        params['p'], state['p'], next(rngs), S, is_training)
                    S_next_rep = jax.tree_util.tree_map(
                        self._reshape_to_replicas, S_next_rep)

                # r ~ p(r|s,a)  # note: R is replicated, one for each (discrete) action
                if is_stochastic(self.r):
                    dist_params_rep, new_state['r'] = self.r.function_type2(
                        params['r'], state['r'], next(rngs), S, is_training)
                    dist_params_rep = jax.tree_util.tree_map(
                        self._reshape_to_replicas, dist_params_rep)
                    R_rep = self.r.proba_dist.sample(dist_params_rep,
                                                     next(rngs))
                    R_rep = self.r.proba_dist.postprocess_variate(
                        next(rngs), R_rep, batch_mode=True)
                else:
                    R_rep, new_state['r'] = self.r.function_type2(
                        params['r'], state['r'], next(rngs), S, is_training)
                    R_rep = jax.tree_util.tree_map(
                        self._reshape_to_replicas, R_rep)

                # v(s')  # note: since the input S_next is replicated, so is the output V
                if is_stochastic(self.v):
                    dist_params_rep, new_state['v'] = self.v.function(
                        params['v'], state['v'], next(rngs), S_next_rep,
                        is_training)
                    V_rep = self.v.proba_dist.sample(dist_params_rep,
                                                     next(rngs))
                    V_rep = self.v.proba_dist.postprocess_variate(
                        next(rngs), V_rep, batch_mode=True)
                else:
                    V_rep, new_state['v'] = self.v.function(
                        params['v'], state['v'], next(rngs), S_next_rep,
                        is_training)

                # q = r + gamma * v(s'), computed in the transformed value space
                f, f_inv = self.value_transform
                Q_rep = f(R_rep + params['gamma'] * f_inv(V_rep))

                # reshape from (batch x num_actions, *) to (batch, num_actions, *)
                Q_s = self._reshape_from_replicas(Q_rep)
                assert Q_s.ndim == 2, f"bad shape: {Q_s.shape}"
                assert Q_s.shape[
                    1] == self.action_space.n, f"bad shape: {Q_s.shape}"

                new_state = hk.data_structures.to_immutable_dict(new_state)
                assert jax.tree_util.tree_structure(
                    new_state) == jax.tree_util.tree_structure(state)

                return Q_s, new_state
Example #15
0
def random_split_like_tree(rng_key, target=None, treedef=None):
    """Split ``rng_key`` into one fresh key per leaf of a pytree.

    Args:
        rng_key: JAX PRNG key to split.
        target: Pytree whose structure determines the output. Only used when
            ``treedef`` is None.
        treedef: Optional precomputed treedef; takes precedence over
            ``target``.

    Returns:
        A pytree with the structure of ``treedef`` (or of ``target``) whose
        leaves are the keys returned by
        ``jax.random.split(rng_key, num_leaves)``, in flattened order.
    """
    if treedef is None:
        treedef = jax.tree_util.tree_structure(target)
    keys = jax.random.split(rng_key, treedef.num_leaves)
    return jax.tree_util.tree_unflatten(treedef, keys)
Example #16
0
    def __call__(self, inputs, state):
        """Run one step of the wrapped core, handling state reset.

        Args:
          inputs: Tuple with two elements, ``inputs, should_reset``, where
            ``should_reset`` is the signal used to reset the wrapped core's
            state. ``should_reset`` can be either tensor or nest. If nest,
            ``should_reset`` must match the state structure, and its
            components' shapes must be prefixes of the corresponding state
            tensors' shapes. If tensor, supported shapes are all common shape
            prefixes of the state component tensors, e.g. ``[batch_size]``.
          state: Previous wrapped core state.

        Returns:
          Tuple of the wrapped core's ``output, next_state``.
        """
        inputs, should_reset = inputs
        if jax.tree_util.treedef_is_leaf(
                jax.tree_util.tree_structure(should_reset)):
            # Equivalent to not tree.is_nested, but with support for Jax
            # extensible pytrees. Broadcast the single signal to every state
            # component.
            should_reset = jax.tree_util.tree_map(lambda _: should_reset,
                                                  state)

        # We now need to manually pad 'on the right' to ensure broadcasting operates
        # correctly.
        # Automatic broadcasting would in fact implicitly pad 'on the left',
        # resulting in the signal to trigger resets for parts of the state
        # across batch entries. For example:
        #
        # import jax
        # import jax.numpy as jnp
        #
        # shape = (2, 2, 2)
        # x = jnp.zeros(shape)
        # y = jnp.ones(shape)
        # should_reset = jnp.array([False, True])
        # v = jnp.where(should_reset, x, y)
        # for batch_entry in range(shape[0]):
        #   print("batch_entry {}:\n".format(batch_entry), v[batch_entry])
        #
        # >> batch_entry 0:
        # >>  [[1. 0.]
        # >>  [1. 0.]]
        # >> batch_entry 1:
        # >>  [[1. 0.]
        # >>  [1. 0.]]
        #
        # Note how manually padding the should_reset tensor yields the desired
        # behavior.
        #
        # import jax
        # import jax.numpy as jnp
        #
        # shape = (2, 2, 2)
        # x = jnp.zeros(shape)
        # y = jnp.ones(shape)
        # should_reset = jnp.array([False, True])
        # dims_to_add = x.ndim - should_reset.ndim
        # should_reset = should_reset.reshape(should_reset.shape + (1,)*dims_to_add)
        # v = jnp.where(should_reset, x, y)
        # for batch_entry in range(shape[0]):
        #   print("batch_entry {}:\n".format(batch_entry), v[batch_entry])
        #
        # >> batch_entry 0:
        # >>  [[1. 1.]
        # >>  [1. 1.]]
        # >> batch_entry 1:
        # >>  [[0. 0.]
        # >>  [0. 0.]]
        #
        # `jax.tree_multimap` was removed from JAX; `tree_map` accepts
        # several trees and maps over them in lockstep.
        should_reset = jax.tree_util.tree_map(_validate_and_conform,
                                              should_reset, state)
        if self._is_batched(state):
            batch_size = jax.tree_util.tree_leaves(inputs)[0].shape[0]
        else:
            batch_size = None
        # Cast the initial state to the dtypes of the current state.
        initial_state = jax.tree_util.tree_map(
            lambda s, i: i.astype(s.dtype), state,
            self.initial_state(batch_size))
        state = jax.tree_util.tree_map(jnp.where, should_reset, initial_state,
                                       state)
        return self.core(inputs, state)
Example #17
0
 def optimizer_state(self, new_optimizer_state):
     """Replace the optimizer state, requiring an unchanged pytree structure.

     Raises:
         AttributeError: If ``new_optimizer_state`` does not have the same
             pytree structure as the current ``self.optimizer_state``.
     """
     new_structure = jax.tree_util.tree_structure(new_optimizer_state)
     if new_structure != jax.tree_util.tree_structure(self.optimizer_state):
         raise AttributeError(
             "cannot set optimizer_state attr: mismatch in tree structure")
     self._optimizer_state = new_optimizer_state
Example #18
0
 def params(self, new_params):
     """Replace the parameters, requiring an unchanged pytree structure.

     Raises:
         TypeError: If ``new_params`` does not have the same pytree structure
             as the current ``self._params``.
     """
     new_structure = jax.tree_util.tree_structure(new_params)
     if new_structure != jax.tree_util.tree_structure(self._params):
         raise TypeError(
             "new params must have the same structure as old params")
     self._params = new_params
Example #19
0
 def catch_treedef(scopes, *args):
   """Call ``wrapper(scopes, treedef, *args)`` with the treedef of ``scopes``.

   ``wrapper`` is a name from the enclosing scope.
   """
   treedef = jax.tree_util.tree_structure(scopes)
   return wrapper(scopes, treedef, *args)
Example #20
0
    def make(
        factors: Iterable[FactorBase],
        use_onp: bool = True,
    ) -> "StackedFactorGraph":
        """Create a factor graph from a set of factors.

        Factors are grouped by tree structure and leaf shapes. Each group
        becomes one ``FactorStack``.

        Args:
            factors: Factors to include in the graph. Within a group, each
                factor must appear only once.
            use_onp: Forwarded to ``FactorStack.make``.

        Returns:
            A ``StackedFactorGraph`` holding one factor stack per group, the
            concatenated Jacobian coordinates, the local storage metadata,
            and the total residual dimension.

        Raises:
            ValueError: If a factor is repeated within a group.
        """
        # Group factors and collect variables in first-seen order; a dict with
        # ``None`` values serves as an insertion-ordered set.
        factors_from_group: DefaultDict[GroupKey,
                                        List[FactorBase]] = defaultdict(list)
        variables_ordered_set: Dict[VariableBase, None] = {}
        for factor in factors:
            # Each factor is ultimately just a PyTree node; in order for a set
            # of factors to be batchable, they must share the same:
            group_key: GroupKey = (
                # (1) Tree structure: this encompasses the factor class,
                # variable types
                jax.tree_util.tree_structure(factor),
                # (2) Leaf shapes: array shapes must match in order to be
                # stacked
                tuple(leaf.shape for leaf in jax.tree_util.tree_leaves(factor)),
            )

            # Record factor and variables
            factors_from_group[group_key].append(factor)
            for v in factor.variables:
                variables_ordered_set[v] = None
        variables = list(variables_ordered_set)

        # Make sure factors are unique. Use a distinct loop name so the
        # `factors` argument is not shadowed, and raise instead of asserting
        # so the check survives `python -O`.
        for group in factors_from_group.values():
            if len(group) != len(set(group)):
                raise ValueError("Duplicate factors are not allowed.")

        # Fields we want to populate
        stacked_factors: List[FactorStack] = []
        jacobian_coords: List[sparse.SparseCooCoordinates] = []

        # Create storage metadata: this determines which parts of our storage
        # object are allocated to each variable type
        storage_metadata = StorageMetadata.make(variables, local=False)
        local_storage_metadata = StorageMetadata.make(variables, local=True)

        # Prepare each factor group; each group's residual rows are placed
        # after those of all previous groups.
        residual_offset = 0
        for group in factors_from_group.values():
            # Make factor stack
            stacked_factors.append(
                FactorStack.make(
                    group,
                    storage_metadata,
                    use_onp=use_onp,
                ))

            # Compute Jacobian coordinates
            #
            # These should be N pairs of (row, col) indices, where rows
            # correspond to residual indices and columns correspond to local
            # parameter indices
            jacobian_coords.extend(
                FactorStack.compute_jacobian_coords(
                    factors=group,
                    local_storage_metadata=local_storage_metadata,
                    row_offset=residual_offset,
                ))
            residual_offset += stacked_factors[-1].get_residual_dim()

        # Concatenate the per-group coordinate pytrees leaf by leaf
        # (`jax.tree_multimap` was removed; `tree_map` accepts many trees).
        # NOTE(review): with no factors, `jacobian_coords` is empty and
        # `tree_map` gets no tree to map over (TypeError); presumably callers
        # always pass at least one factor — confirm.
        jacobian_coords_concat: sparse.SparseCooCoordinates = (
            jax.tree_util.tree_map(
                lambda *arrays: onp.concatenate(arrays, axis=0),
                *jacobian_coords))

        return StackedFactorGraph(
            factor_stacks=stacked_factors,
            jacobian_coords=jacobian_coords_concat,
            local_storage_metadata=local_storage_metadata,
            residual_dim=residual_offset,
        )