def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  Args:
    x: Tensor of activations.

  Returns:
    Tensor of same shape and dtype as the input.
  """
  if self._mode != 'train':
    return x
  state, rng = self.state, self.rng
  rate = self._initial_rate
  if isinstance(state, dict) and self._name in state:
    rate = state[self._name]
  mask_shape = list(x.shape)
  for axis in self._shared_axes:
    mask_shape[axis] = 1
  if fastmath.is_backend(fastmath.Backend.JAX):
    keep_prob = jax.lax.tie_in(self.rng, 1.0 - rate)
  else:
    keep_prob = 1.0 - rate
  keep = fastmath.random.bernoulli(rng, keep_prob, tuple(mask_shape))
  if fastmath.is_backend(fastmath.Backend.JAX):
    keep_prob = jax.lax.tie_in(keep, keep_prob)
  mask = keep.astype(x.dtype) / keep_prob
  return x * mask

def _jax_and_tf_configure_for_devices():  # pylint: disable=missing-function-docstring
  jax.config.enable_omnistaging()
  if FLAGS.use_tpu:
    jax.config.update('jax_platform_name', 'tpu')
    jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)
  if (FLAGS.enable_eager_execution and
      (fastmath.is_backend(Backend.NUMPY) or fastmath.is_backend(Backend.JAX))):
    # Numpy backend doesn't benefit from having the input pipeline run on GPU,
    # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
    # set up first before determining the backend.
    tf.config.experimental.set_visible_devices([], 'GPU')

def f(x):
  if n_devices > 1 and fastmath.is_backend(fastmath.Backend.JAX):
    return _multi_device_put(x)
  elif n_devices > 1:
    return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape)
  else:
    return x

def one_hot(x, n_categories, dtype=jnp.float32):  # pylint: disable=invalid-name
  """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims)."""
  indices_less_than_n = jnp.arange(n_categories)
  if fastmath.is_backend(fastmath.Backend.JAX):
    # Work around a jax broadcasting issue.
    indices_less_than_n = jax.lax.tie_in(x, indices_less_than_n)
  return jnp.array(x[..., jnp.newaxis] == indices_less_than_n, dtype)

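# A minimal usage sketch for `one_hot` (illustrative only; `_one_hot_example`
# is a hypothetical helper, not part of the library): a (3,)-shaped int array
# with 4 categories yields a (3, 4) float array whose rows are unit vectors.
def _one_hot_example():
  x = jnp.array([0, 2, 3])
  y = one_hot(x, n_categories=4)
  # y == [[1., 0., 0., 0.],
  #       [0., 0., 1., 0.],
  #       [0., 0., 0., 1.]]
  return y
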
def f(x):
  if n_devices > 1 and fastmath.is_backend(fastmath.Backend.JAX):
    return jax.device_put_replicated(x, jax.local_devices())
  elif n_devices > 1:
    return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape)
  else:
    return x

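# A minimal sketch of how a replication helper like `f` above is typically
# applied (illustrative only; `f` and `n_devices` are assumed to be in scope
# as in the enclosing function, and `_replicate_tree_example` is hypothetical):
# mapping it over a tree of weights gives every leaf a leading device axis,
# e.g. a (4,)-shaped array becomes (n_devices, 4) on non-JAX backends, or one
# copy per local device under JAX.
def _replicate_tree_example(weights_tree):
  return fastmath.nested_map(f, weights_tree)
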
def save_state(self, keep, prefix='model'):
  """Save trainer state given a possibly replicated opt_state."""
  opt_state = self._opt_state
  if self.n_devices > 1:
    first_replica = lambda x: x[0]
    opt_state = OptState(*fastmath.nested_map(first_replica, opt_state))
  # This line, while optional, allows JAX to transfer arrays from the device
  # to the host in parallel, which is particularly important for cloud TPU.
  if fastmath.is_backend(fastmath.Backend.JAX):
    opt_state = jax.device_get(opt_state)
  step, history, model_state = self._step, self._history, self._model_state
  output_dir = self._output_dir
  weights_file = os.path.join(output_dir, prefix + '.pkl.gz')

  # This dict will be stored as the model.
  trainer_state_dict = make_trainer_state_dict(
      step, opt_state, history, model_state, self._input_signature)
  self._save_state_dict(trainer_state_dict, weights_file)

  if keep:
    weights_file = os.path.join(output_dir,
                                '{}_{}.pkl.gz'.format(prefix, step))
    self._save_state_dict(trainer_state_dict, weights_file)

def forward(self, inputs):
  """Returns attention-computed activations.

  Args:
    inputs: A (queries, keys, values) tuple.
  """
  q, k, v = inputs

  if self._mode == 'predict':
    self.state = _fast_inference_update_state(inputs, self.state)
    (k, v, mask, _) = self.state
  else:
    mask_size = q.shape[-2]
    # Not all backends define jnp.tril. However, using np.tril is inefficient
    # in that it creates a large global constant. TODO(kitaev): try to find an
    # alternative that works across all backends.
    if fastmath.is_backend(fastmath.Backend.JAX):
      mask = jnp.tril(
          jnp.ones((1, mask_size, mask_size), dtype=np.bool_), k=0)
    else:
      mask = np.tril(
          np.ones((1, mask_size, mask_size), dtype=np.bool_), k=0)

  res, dots = DotProductAttention(
      q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng)
  if self._mode == 'viz':
    self.state = dots
  return res

def _l2_norm(self, flat_list):
  """Returns the aggregate L2 norm of a list of tensors."""
  if fastmath.is_backend(fastmath.Backend.JAX):
    norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in flat_list))
  else:  # TODO(lukaszkaiser): add vdot to TF-numpy
    norm = jnp.sqrt(sum(jnp.sum(x*x) for x in flat_list))
  return norm

def _causal_mask(length):
  # Not all backends define jnp.tril. However, using np.tril is inefficient
  # in that it creates a large global constant. TODO(kitaev): try to find an
  # alternative that works across all backends.
  if fastmath.is_backend(fastmath.Backend.JAX):
    return jnp.tril(jnp.ones((1, length, length), dtype=np.bool_), k=0)
  else:
    return np.tril(np.ones((1, length, length), dtype=np.bool_), k=0)

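# A minimal usage sketch for `_causal_mask` (illustrative only; the helper
# below is hypothetical): for length=3 the mask is lower-triangular, so
# position i may attend only to positions j <= i.
def _causal_mask_example():
  mask = _causal_mask(3)  # shape (1, 3, 3), dtype bool
  expected = np.array([[[True, False, False],
                        [True, True, False],
                        [True, True, True]]])
  assert np.array_equal(np.asarray(mask), expected)
  return mask
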
def network_policy(
    collect_model,
    policy_distribution,
    loop,
    trajectory_np,
    head_index=0,
    temperature=1.0,
):
  """Policy function powered by a neural network.

  Used to implement Agent.policy() in policy-based agents.

  Args:
    collect_model: the model used for collecting trajectories
    policy_distribution: an instance of trax.rl.distributions.Distribution
    loop: trax.supervised.training.Loop used to train the policy network
    trajectory_np: an instance of trax.rl.task.TrajectoryNp
    head_index: index of the policy head in a multihead model.
    temperature: temperature used to sample from the policy (default=1.0)

  Returns:
    a pair (action, dist_inputs) where action is the action taken and
    dist_inputs is the parameters of the policy distribution, which will later
    be used for training.
  """
  if temperature == 1.0:
    model = collect_model
  else:
    # When evaluating (t != 1.0), use the evaluation model instead of the
    # collection model - some models accumulate normalization statistics
    # during data collection, and we don't want to do it in eval to avoid data
    # leakage.
    model = loop.eval_model
    model.state = collect_model.state
  # Copying weights from loop.model should work, because the raw model's
  # weights should be updated automatically during training, but it doesn't.
  # TODO(pkozakowski): Debug.
  acc = loop._trainer_per_task[0].accelerated_model_with_loss  # pylint: disable=protected-access
  model.weights = acc._unreplicate(acc.weights[0])  # pylint: disable=protected-access

  # Add a batch dimension to trajectory_np and run the model.
  pred = model(trajectory_np.observations[None, ...])

  if isinstance(pred, (tuple, list)):
    # For multihead models, extract the policy head output.
    pred = pred[head_index]

  assert pred.shape == (
      1, trajectory_np.observations.shape[0], policy_distribution.n_inputs
  )

  # Pick element 0 from the batch (the only one), last (current) timestep.
  pred = pred[0, -1, :]

  sample = policy_distribution.sample(pred, temperature=temperature)
  result = (sample, pred)
  if fastmath.is_backend(fastmath.Backend.JAX):
    # The result is composed of mutable numpy arrays. We copy them to avoid
    # accidental modification.
    result = fastmath.nested_map(lambda x: x.copy(), result)
  return result

def init_host_and_devices(n_devices=None, random_seed=None):
  """Initializes host and device attributes for this trainer.

  Args:
    n_devices: Number of devices this trainer will use. If `None`, get the
        number from the backend.
    random_seed: Random seed as the starting point for all random numbers used
        by the trainer. If `None`, calculate one from system time and host id.

  Returns:
    is_chief: True if this trainer has special chief responsibilities.
    host_count: Number of hosts in this computation.
    n_devices: The passed in value of n_devices or a computed default (for this
      host).
    random_seed: The passed in value of random_seed or a computed default.
  """
  if fastmath.is_backend(fastmath.Backend.JAX):
    host_id = jax.host_id()
    host_count = jax.host_count()
  else:
    host_id = 0
    host_count = 1
  is_chief = (host_id == 0)

  logging.info('Initializing hosts and devices: host_id %d, host_count %d, '
               'is_chief %d', host_id, host_count, is_chief)

  device_count = fastmath.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count and fastmath.is_backend(fastmath.Backend.JAX):
    raise ValueError('JAX cannot work yet with n_devices != all devices: '
                     '%d != %d' % (n_devices, device_count))

  if random_seed is None and host_count > 1:
    random_seed = int(1e6 * (host_id + time.time())) % 2**32
  return (is_chief, host_count, n_devices,
          _init_random_number_generators(random_seed))

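# A minimal usage sketch for `init_host_and_devices` (illustrative only;
# `_init_host_and_devices_example` is a hypothetical helper): on a single-host
# setup with a fixed seed, the call reports this host as chief and defaults
# `n_devices` to everything the backend exposes.
def _init_host_and_devices_example():
  is_chief, host_count, n_devices, rng = init_host_and_devices(
      n_devices=None, random_seed=0)
  # On a single-host run, is_chief is True and host_count is 1.
  return is_chief, host_count, n_devices, rng
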
def main(_):
  logging.set_verbosity(FLAGS.log_level)

  _tf_setup_from_flags()
  _gin_parse_configs()
  _jax_and_tf_configure_for_devices()

  output_dir = _output_dir_or_default()
  if FLAGS.use_tpu and fastmath.is_backend(Backend.TFNP):
    _train_using_tf(output_dir)
  else:
    trainer_lib.train(output_dir=output_dir)

  trainer_lib.log('Finished training.')

def _to_bits(self, weights):
  """Converts a list of weights to bit-cast weights and their types."""
  # This is currently needed to pickle bfloat16 arrays from JAX.
  # TODO(lukaszkaiser): remove once it is not needed (the following unit test
  # checks it: training_test/test_restores_step_bfloat16).
  if not fastmath.is_backend(fastmath.Backend.JAX):
    return weights
  bits = []
  for w in weights:
    if w.dtype == jnp.bfloat16:
      bits.append((jax.lax.bitcast_convert_type(w, np.uint16), 'bfloat16'))
    else:  # for non-bfloat16 weights, be compatible with earlier checkpoints
      bits.append(w)
  return bits

def main(_):
  logging.set_verbosity(FLAGS.log_level)

  _tf_setup_from_flags()
  _gin_parse_configs()
  _jax_and_tf_configure_for_devices()

  # Create a JAX GPU cluster if using JAX and given a chief IP.
  if fastmath.is_backend(Backend.JAX) and FLAGS.gpu_cluster_chief_ip:
    _make_jax_gpu_cluster(FLAGS.gpu_cluster_host_id,
                          FLAGS.gpu_cluster_chief_ip,
                          FLAGS.gpu_cluster_n_hosts,
                          FLAGS.gpu_cluster_port)

  if FLAGS.disable_jit:
    fastmath.disable_jit()

  output_dir = _output_dir_or_default()
  if FLAGS.use_tpu and fastmath.is_backend(Backend.TFNP):
    _train_using_tf(output_dir)
  else:
    trainer_lib.train(output_dir=output_dir)

  trainer_lib.log('Finished training.')

def _from_bits(self, bits_and_types):
  """Converts a list of bit-cast weights and their types back to weights."""
  # This is the reverse of _to_bits, see above for explanation.
  if not fastmath.is_backend(fastmath.Backend.JAX):
    return bits_and_types
  weights = []
  for bits_and_dtype in bits_and_types:
    if isinstance(bits_and_dtype, tuple):
      bits, dtype = bits_and_dtype
      assert dtype == 'bfloat16'
      w = jax.lax.bitcast_convert_type(bits, jnp.bfloat16)
      weights.append(w)
    else:
      weights.append(bits_and_dtype)
  return weights

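# A minimal round-trip sketch for `_to_bits` / `_from_bits` (illustrative
# only; assumes the JAX backend and `checkpointer` is a hypothetical object
# exposing both methods): bfloat16 weights are bit-cast to uint16 before
# pickling and restored to their original dtype afterwards, while other dtypes
# pass through unchanged.
def _bits_round_trip_example(checkpointer):
  weights = [jnp.ones((2, 2), dtype=jnp.bfloat16), jnp.zeros((3,))]
  bits = checkpointer._to_bits(weights)      # bfloat16 -> (uint16 bits, 'bfloat16')
  restored = checkpointer._from_bits(bits)   # back to the original dtypes
  assert restored[0].dtype == jnp.bfloat16
  return restored
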
def _forward_and_or_backward(layer):
  """Create forward_and_or_backward for layers that don't define it."""
  # TODO(lukaszkaiser): remove these 2 lines once PR #4039 lands for JAX.
  if fastmath.is_backend(fastmath.Backend.JAX):
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def forward_and_or_backward(inputs, weights, state, rng, output_grad=None,
                              compute_output=True, update_state=True):
    """Performs batched forward and/or backward passes.

    Args:
      inputs: inputs to the attention layer
      weights: weights for the attention layer
      state: state of the attention layer
      rng: PRNG key for the layer (shared across all examples and heads)
      output_grad: gradient of the loss wrt the output of the layer, or None.
          This function performs the backward pass iff `output_grad` is not
          None.
      compute_output: bool: whether to return the output of the forward pass
          (for example, a pure backwards pass does not need to return the
          output).
      update_state: bool: whether to return an updated layer state.

    Returns:
      A tuple (output, new_state, inputs_grad, weights_grad).
      - output is not None iff compute_output is True
      - new_state is not None iff update_state is True
      - inputs_grad & weights_grad are not None iff output_grad is not None
    """
    # We need a layer pure_fn but only for inputs and weights.
    def pure_fn_without_state_and_rng(x, w):
      return layer.pure_fn(x, w, state, rng)

    # Calculate the vector-Jacobian product of the layer pure_fn.
    output, vjp_fn, new_state = fastmath.vjp(
        pure_fn_without_state_and_rng, inputs, weights, has_aux=True)
    output = output if compute_output else None
    new_state = new_state if update_state else None

    # The vjp function returns gradients with respect to inputs and weights.
    if output_grad is not None:
      grads_inputs, grads_weights = vjp_fn(output_grad)
    else:
      grads_inputs, grads_weights = None, None

    return (output, new_state, grads_inputs, grads_weights)
  return forward_and_or_backward

def _l2_norm(self, flat_list):
  """Returns an L2-like norm of all elements of all tensors in `flat_list`.

  Args:
    flat_list: Collection of tensors as a flat list (rather than, e.g., a
        tree).

  Returns:
    A scalar value computed as if all the tensors in `flat_list` were joined
    and flattened into a single vector, and then the L2 norm of that vector
    was calculated.
  """
  if fastmath.is_backend(fastmath.Backend.JAX):
    norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in flat_list))
  else:  # TODO(lukaszkaiser): add vdot to TF-numpy
    norm = jnp.sqrt(sum(jnp.sum(x * x) for x in flat_list))
  return norm

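# A minimal numeric sketch of the norm computed above (illustrative only; the
# standalone helper below mirrors the method body without `self`): the tensors
# [3.] and [4.] behave like the single vector [3., 4.], whose L2 norm is 5.
def _l2_norm_example():
  flat_list = [jnp.array([3.0]), jnp.array([4.0])]
  norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in flat_list))
  # sqrt(3^2 + 4^2) == 5.0
  return norm
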
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `queries` and `keys`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for dropout applied to attention strengths
        (based on query-key pairs) before applying them to values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    A pair: the per-head activations resulting from the masked per-head
    attention-weighted sum of per-head values, and the attention weights used
    to compute them.
  """
  d_feature = queries.shape[-1]
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data dependency
    # on the input. Broadcasted copies of these use a lot of memory, so they
    # should be computed at runtime (rather than being global constants).
    if fastmath.is_backend(fastmath.Backend.JAX):
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  out = out.astype(jnp.float32)
  dots = dots.astype(jnp.float32)
  return out, dots

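# A minimal usage sketch for `DotProductAttention` (illustrative only;
# `_dot_product_attention_example` is a hypothetical helper): a single head
# with batch 1, two positions and feature depth 2, run in 'eval' mode so
# dropout is inactive. The all-True mask allows attention everywhere.
def _dot_product_attention_example():
  q = jnp.ones((1, 2, 2))
  k = jnp.ones((1, 2, 2))
  v = jnp.ones((1, 2, 2))
  mask = jnp.ones((1, 2, 2), dtype=np.bool_)
  out, dots = DotProductAttention(
      q, k, v, mask, dropout=0.0, mode='eval', rng=None)
  # With uniform inputs, the attention weights are uniform (0.5 each) and the
  # output equals the values.
  return out, dots
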
def policy(self, trajectory, temperature=1.0):
  """Chooses an action to play after a trajectory."""
  model = self._policy_collect_model
  if temperature != 1.0:
    # When evaluating (t != 1.0), don't collect stats.
    model = self._policy_eval_model
    model.state = self._policy_collect_model.state
  model.replicate_weights(self._policy_trainer.model_weights)
  tr_slice = trajectory[-self._max_slice_length:]
  trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
  # Add batch dimension to trajectory_np and run the model.
  pred = model(trajectory_np.observations[None, ...])
  # Pick element 0 from the batch (the only one), last (current) timestep.
  pred = pred[0, -1, :]
  sample = self._policy_dist.sample(pred, temperature=temperature)
  result = (sample, pred)
  if fastmath.is_backend(fastmath.Backend.JAX):
    result = fastmath.nested_map(lambda x: x.copy(), result)
  return result

def forward(self, inputs):
  """Returns the input activations, with added positional information."""
  if self._mode != 'predict':
    x = inputs
    symbol_size = jnp.shape(x)[1]
    px = self.weights[:, :symbol_size, :]
    if self._dropout == 0:
      return x + px
    else:
      noise_shape = list(px.shape)
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      if fastmath.is_backend(fastmath.Backend.JAX):
        keep_prob = jax.lax.tie_in(x, jnp.full((), keep_prob, dtype=x.dtype))
      keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                       tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / keep_prob
      return x + px * multiplier
  else:
    if self._dropout != 0:
      raise ValueError(f'In predict mode, but dropout rate '
                       f'({self._dropout}) is not zero.')

    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer needs to store the index of the current
    # position then and increment it on each call -- that's how state is used
    # and updated below.
    state = self.state
    if inputs.shape[1] == 1:
      self.state = state + 1
      return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
    else:
      emb = []
      for i in range(inputs.shape[0]):
        emb.append(jax.lax.dynamic_slice_in_dim(
            self.weights[0], state[i], inputs.shape[1], axis=0))
      self.state = state + inputs.shape[1]
      return inputs + jnp.stack(emb, 0)

def forward(self, inputs):
  rng, state = self.rng, self.state
  embs = []
  for ax_emb in self.weights:
    ax_emb = jnp.broadcast_to(
        ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],))
    embs.append(ax_emb)

  if self._mode == 'predict':
    assert self._dropout == 0.0
    emb = jnp.concatenate(embs, -1)
    emb = jnp.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
    emb = jax.lax.dynamic_slice_in_dim(emb, state, inputs.shape[1], axis=1)
    self.state = state + inputs.shape[1]
    return inputs + emb
  elif self._dropout == 0:
    # TODO(kitaev): concat-then-reshape (as is the case with dropout enabled)
    # leads to memory blow-up on TPU.
    # emb = jnp.concatenate(embs, -1)
    # return inputs + jnp.reshape(emb, inputs.shape), state
    return inputs + jnp.concatenate(
        [jnp.reshape(emb, inputs.shape[:-1] + (emb.shape[-1],))
         for emb in embs],
        -1)
  else:
    emb = jnp.concatenate(embs, -1)
    noise_shape = list(emb.shape)
    for dim in self._dropout_broadcast_dims:
      noise_shape[dim] = 1
    keep_prob = 1.0 - self._dropout
    if fastmath.is_backend(fastmath.Backend.JAX):
      keep_prob = jax.lax.tie_in(
          inputs, jnp.full((), keep_prob, dtype=inputs.dtype))
    keep = fastmath.random.bernoulli(rng, keep_prob, tuple(noise_shape))
    multiplier = keep.astype(inputs.dtype) / keep_prob
    return inputs + jnp.reshape(emb * multiplier, inputs.shape)

def _fast_inference_update_state(inputs, state):
  """Updates state of a causal attention layer for fast inference.

  The layer state stores tensors with cached values of keys and values, as
  well as the mask and an index. To make shapes static, keys and values in
  the state are long, and the index indicates where the new keys and values
  from inputs need to be appended. The mask ensures that attention will only
  look at keys up to the index.

  During the update, we append new_keys and new_values to keys and values at
  the position given by index. We also update the mask (which starts as
  all-0s) to be 1 at the new keys' positions, and we increment index by the
  length of the new keys.

  Args:
    inputs: a triple (new_queries, new_keys, new_values)
    state: layer state with (keys, values, mask, index)

  Returns:
    Updated state.
  """
  if not fastmath.is_backend(fastmath.Backend.JAX):
    raise ValueError(f'JAX backend is required in predict mode, but found '
                     f"backend ({fastmath.backend()['name']}).")

  # Fast inference: run step-by-step, storing the sequence
  # of keys and values calculated so far in state.
  (_, new_k, new_v) = inputs
  length = new_k.shape[1]
  (ks, vs, mask, idx) = state
  # TODO(lukaszkaiser): benchmark speed and decide if using a separate code path
  # with index_update when length == 1 is worth it.
  # Keys and values are of shape [batch_size, length, d_kv].
  ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)
  vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)
  # Mask is of shape [batch_size, 1 (for heads), length].
  new_mask = jnp.ones((mask.shape[0], mask.shape[1], length))
  mask = fastmath.dynamic_update_slice_in_dim(mask, new_mask, idx, axis=2)
  return (ks, vs, mask, idx + length)

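# A minimal shape sketch for `_fast_inference_update_state` (illustrative
# only; assumes the JAX backend, and `_fast_inference_state_example` is a
# hypothetical helper): with a cache pre-allocated for 8 positions and one new
# token, the key/value caches keep their static shapes while the mask and the
# index advance by one.
def _fast_inference_state_example():
  batch_size, max_len, d_kv = 1, 8, 4
  state = (jnp.zeros((batch_size, max_len, d_kv)),   # cached keys
           jnp.zeros((batch_size, max_len, d_kv)),   # cached values
           jnp.zeros((batch_size, 1, max_len)),      # mask, all 0s initially
           0)                                        # write index
  new_q = new_k = new_v = jnp.ones((batch_size, 1, d_kv))
  ks, vs, mask, idx = _fast_inference_update_state(
      (new_q, new_k, new_v), state)
  # Shapes stay (1, 8, 4), (1, 8, 4), (1, 1, 8); idx is now 1.
  return ks, vs, mask, idx
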
def __init__(self, blocks, loss_layer, optimizer_fn, n_devices=None):
  """Creates a ReversibleSerialTrainer and the needed optimizers.

  This trainer performs updates equivalent to using the default Trainer on::

    tl.Serial(blocks + [loss_layer]).

  It is more memory-efficient though since weights are stored on CPU and only
  sent to accelerator layer-by-layer. Blocks are pairs consisting of a list
  of standard (arbitrary) layers and a list of reversible layers which help
  save memory thanks to being reversible.

  Args:
    blocks: A list of pairs of lists of standard and reversible layers.
    loss_layer: The final layer of the model; it can have trainable weights
      but should end with a loss: it is required to produce a scalar output.
    optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
    n_devices: An optional integer, number of accelerator devices to use;
      by default, all available accelerators will be used.
  """
  # TODO(lukaszkaiser): remove these 2 lines once PR #4039 lands for JAX.
  if fastmath.is_backend(fastmath.Backend.JAX):
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access
  self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
  self._loss_layer = loss_layer
  self._optimizer_fn = optimizer_fn
  self._n_devices = n_devices or fastmath.device_count()
  self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks])

  # Create accelerated versions of layers as pmaped/jited pure_fn.
  self._accelerated_layer_fns = fastmath.nested_map(
      lambda layer: self._pjit(layer.pure_fn), self._blocks)

  # Create per-layer optimizers and replicate opt_params.
  def _make_optimizer(layer):
    opt = optimizer_fn()
    opt.tree_init(layer.weights)
    return opt

  self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
  self._replicated_opt_params = fastmath.nested_map(
      lambda opt: self._replicate(opt.opt_params), self._optimizers)

  self._loss_opt = _make_optimizer(loss_layer)
  self._replicated_loss_opt_params = self._replicate(
      self._loss_opt.opt_params)

  # Forward + backward + optimizer-update functions for all layers.
  # We call them in short FBO for "Forward + Backward + Optimizer update".
  # Reversible layers define a reverse_and_fbo function that also reverses.
  self._fbos = []
  for i, (std_layer, rev_layers) in enumerate(self._blocks):
    (std_opt, rev_opts) = self._optimizers[i]
    std_fbo = _fbo_with_layer_and_opt(std_layer, std_opt, self._n_devices)
    rev_and_fbos = []
    for layer, opt in zip(rev_layers, rev_opts):
      rev_and_fbos.append(self._pjit(
          _reverse_and_fbo_with_layer_and_opt(layer, opt, self._n_devices)))
    self._fbos.append((self._pjit(std_fbo), rev_and_fbos))

  loss_fbo = _fbo_with_layer_and_opt(
      self._loss_layer, self._loss_opt, self._n_devices, 'loss')
  self._loss_fbo = self._pjit(loss_fbo)

def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.WeightedCategoryCrossEntropy(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          permanent_checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          permanent_checkpoint_frequency=None,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=True,
          loss_chunk_size=0,
          use_memory_efficient_trainer=False,
          adasum=False,
          init_checkpoint=None,
          callbacks=None,
          additional_train_tasks=None,
          additional_eval_tasks=None,
          additional_eval_streams=None):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state,
      rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule_fn: A learning rate schedule function, that when called returns
      a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    permanent_checkpoints_at: list of integers. Save a permanent checkpoint for
      each training step in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    permanent_checkpoint_frequency: int, how often to save permanent checkpoints
      (every permanent_checkpoint_frequency steps).
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint highest at this metric.
    checkpoint_lowest: save the checkpoint lowest at this metric.
    use_loop: whether to use training.Loop instead of Trainer.
    loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
    use_memory_efficient_trainer: whether to use memory-efficient trainer.
    adasum: if True, use adaptive summation for multi-device gradients.
    init_checkpoint: a checkpoint for fine tuning.
    callbacks: a list of callbacks to call during training.
    additional_train_tasks: additional tasks which should be performed during
      training.
    additional_eval_tasks: additional tasks which should be performed during
      evaluation.
    additional_eval_streams: List[NamedStream], additional data streams that
      should be used during evaluation. Can be provided independently of
      additional_eval_tasks.

  Returns:
    trax.TrainerState or training.Loop if use_loop is True
  """
  if (permanent_checkpoint_frequency is not None
      and permanent_checkpoints_at is not None):
    raise ValueError('Only one of ["permanent_checkpoint_frequency", '
                     '"permanent_checkpoints_at"] should be set.')
  if use_loop:
    n_devices = num_devices() or fastmath.local_device_count()

    # Prepare the training task.
    # Inputs is either an Inputs instance or a function that returns it.
    if callable(inputs):  # If we pass a function, e.g., through gin, call it.
      inputs = inputs()
    opt = optimizer if use_memory_efficient_trainer else optimizer()
    train_task = training.TrainTask(
        inputs.train_stream(n_devices),
        loss_layer=loss_fn,
        optimizer=opt,
        lr_schedule=lr_schedule_fn(),
        n_steps_per_checkpoint=eval_frequency,
        n_steps_per_permanent_checkpoint=permanent_checkpoint_frequency)
    if additional_train_tasks is None:
      additional_train_tasks = []

    # Prepare the evaluation.
    metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    names, metrics = zip(*metrics_dict.items())
    eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                  metrics,
                                  metric_names=names,
                                  n_eval_batches=eval_steps)
    if additional_eval_tasks is None:
      additional_eval_tasks = []
    additional_eval_tasks_from_streams = []
    if additional_eval_streams is not None:
      for stream in additional_eval_streams:
        additional_eval_tasks_from_streams.append(
            training.EvalTask(stream.stream,
                              metrics,
                              metric_names=names,
                              n_eval_batches=eval_steps,
                              export_prefix=stream.name))

    # Prepare the training loop.
    checkpoint_at = None
    if checkpoints_at is not None:
      checkpoint_at = lambda step: step in checkpoints_at
    permanent_checkpoint_at = None
    if permanent_checkpoints_at is not None:
      permanent_checkpoint_at = (lambda step: step in permanent_checkpoints_at)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')
    if init_checkpoint:
      model_train.init_from_file(init_checkpoint, weights_only=True)
      model_predict_eval.init_from_file(init_checkpoint, weights_only=True)

    loop = training.Loop(
        model_train, [train_task] + additional_train_tasks,
        eval_model=model_predict_eval,
        eval_tasks=([eval_task] + additional_eval_tasks +
                    additional_eval_tasks_from_streams),
        output_dir=output_dir,
        checkpoint_at=checkpoint_at,
        permanent_checkpoint_at=permanent_checkpoint_at,
        n_devices=n_devices,
        loss_chunk_size=loss_chunk_size,
        use_memory_efficient_trainer=use_memory_efficient_trainer,
        adasum=adasum,
        random_seed=random_seed,
        callbacks=callbacks,
    )

    steps_to_go = steps - loop.step
    if steps_to_go <= 0:
      log('Stop training, already reached the total training steps %d' % steps)
      return loop

    # Train and return the loop.
    loop.run(steps_to_go)
    return loop

  n_devices = num_devices()
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule_fn(), inputs,
                          output_dir,
                          random_seed=random_seed,
                          n_devices=n_devices,
                          checkpoints_at=checkpoints_at,
                          metrics=metrics,
                          checkpoint_lowest=checkpoint_lowest,
                          checkpoint_highest=checkpoint_highest,
                          init_checkpoint=init_checkpoint)

  epoch_steps = [steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  try:
    for epoch_steps in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(epoch_steps, eval_steps)

      # Bookkeeping we do at the first step.
      if trainer.step == 1:
        # Save computation graph (single-device only for now).
        if save_graphs and fastmath.is_backend(fastmath.Backend.JAX):
          trainer.save_computation_graphs()

        # Save Gin config.
        trainer.save_gin()

    trainer.log_step('Training done')
  except Exception as e:
    raise e
  finally:
    trainer.close()
  return trainer.state

def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: target, source.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  if fastmath.is_backend(fastmath.Backend.JAX):
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long
    # sequences.
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        positional_encoding,
    ]

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, tl.SelfAttention, dropout,
                   ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                          ff_dropout, mode)
      for _ in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                # tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                    # tok_e mask tok_d .....

      # Encode.
      encoder,                             # vec_e mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),                # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),            # tok_d vec_e mask .....
      out_encoder,                         # vec_d vec_e mask .....
      tl.Dup(),                            # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
      tl.LayerNorm(),                      # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),              # vec_d .....
      tl.Dense(output_vocab_size),         # vec_d .....
      tl.LogSoftmax(),                     # vec_d .....
  )

def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=64,
              d_attention_value=64,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              axial_pos_shape='infinite',
              d_axial_pos_embs=None,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as
      SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
  # assert d_model // n_heads == d_attention_key, \
  #     f'{d_model} // {n_heads} != {d_attention_key}'
  # assert d_model // n_heads == d_attention_value, \
  #     f'{d_model} // {n_heads} != {d_attention_value}'

  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  if fastmath.is_backend(fastmath.Backend.JAX):
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    positional_encoding = PositionalEncoding(
        mode, dropout, max_len, axial_pos_shape, d_axial_pos_embs)
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        positional_encoding,
    ]

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, encoder_attention_type, dropout,
                   ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([                # tok_e mask_e tok_e tok_d tok_d
      in_encoder,                      # vec_e mask_e tok_e tok_d tok_d
      tl.Dup(),                        # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = []

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]
  for layer_idx in range(n_decoder_layers):
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),          # tok_e tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                 # tok_e mask_e tok_e tok_d tok_d

      # Encode.
      encoder,                          # vec_e mask_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 0, 1, 2]),          #  tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),         # stok_d vec_e mask_e tok_e tok_d
      tl.Branch(
          [],
          _MaskOfRightShiftedArray()
      ),                                # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                      # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder, given their masks.
      tl.Select([2, 0, 3, 1]),          # vec_e svec_d mask_e mask_d tok_e tok_d
      _ConcatWithPadding(),             # vec_ed tok_e tok_d

      # Run (encoder and) decoder blocks.
      tl.Dup(),                             # vec_ed1 vec_ed2 tok_e tok_d
      tl.ReversibleSerial(decoder_blocks),  # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),    # vec_ed tok_e tok_d
      tl.LayerNorm(),                       # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),              # vec_ed tok_e tok_d tok_d
      _StripFromConcatenateWithPadding(),   # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),          # vec_d tok_d
      tl.LogSoftmax(),                      # vec_d tok_d
  )

def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss(),
                            name='CrossEntropyLoss'),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=True,
          loss_chunk_size=0,
          use_memory_efficient_trainer=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state,
      rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule_fn: A learning rate schedule function, that when called returns
      a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint highest at this metric.
    checkpoint_lowest: save the checkpoint lowest at this metric.
    use_loop: whether to use training.Loop instead of Trainer.
    loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
    use_memory_efficient_trainer: whether to use memory-efficient trainer.

  Returns:
    trax.TrainerState or training.Loop if use_loop is True
  """
  if use_loop:
    n_devices = num_devices() or fastmath.device_count()

    # Prepare the training task.
    # Inputs is either an Inputs instance or a function that returns it.
    if callable(inputs):  # If we pass a function, e.g., through gin, call it.
      inputs = inputs()
    opt = optimizer if use_memory_efficient_trainer else optimizer()
    train_task = training.TrainTask(
        inputs.train_stream(n_devices),
        loss_layer=loss_fn,
        optimizer=opt,
        lr_schedule=lr_schedule_fn(),
        n_steps_per_checkpoint=eval_frequency)

    # Prepare the evaluation.
    metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    names, metrics = zip(*metrics_dict.items())
    eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                  metrics,
                                  metric_names=names,
                                  n_eval_batches=eval_steps)

    # Prepare the training loop.
    checkpoint_at = None
    if checkpoints_at is not None:
      checkpoint_at = lambda step: step in checkpoints_at
    loop = training.Loop(
        model(mode='train'), [train_task],
        eval_model=model(mode='eval'),
        eval_tasks=[eval_task],
        output_dir=output_dir,
        checkpoint_at=checkpoint_at,
        n_devices=n_devices,
        loss_chunk_size=loss_chunk_size,
        use_memory_efficient_trainer=use_memory_efficient_trainer,
        random_seed=random_seed)

    steps_to_go = steps - loop.step
    if steps_to_go <= 0:
      log('Stop training, already reached the total training steps %d' % steps)
      return loop

    # Train and return the loop.
    loop.run(steps_to_go)
    return loop

  n_devices = num_devices()
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule_fn(), inputs,
                          output_dir,
                          random_seed=random_seed,
                          n_devices=n_devices,
                          checkpoints_at=checkpoints_at,
                          metrics=metrics,
                          checkpoint_lowest=checkpoint_lowest,
                          checkpoint_highest=checkpoint_highest)

  epoch_steps = [steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  try:
    for epoch_steps in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(epoch_steps, eval_steps)

      # Bookkeeping we do at the first step.
      if trainer.step == 1:
        # Save computation graph (single-device only for now).
        if save_graphs and fastmath.is_backend(fastmath.Backend.JAX):
          trainer.save_computation_graphs()

        # Save Gin config.
        trainer.save_gin()

    trainer.log_step('Training done')
  except Exception as e:
    raise e
  finally:
    trainer.close()
  return trainer.state

def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None,
             random_seed=None,
             n_devices=None,
             checkpoints_at=None,
             should_save_checkpoints=True,
             should_write_summaries=True,
             metrics=None,
             checkpoint_highest=None,
             checkpoint_lowest=None):
  self._is_chief, _, self._n_devices, rng = (
      training.init_host_and_devices(n_devices, random_seed))
  self._should_save_checkpoints = should_save_checkpoints and self._is_chief
  self._checkpoints_at = checkpoints_at if checkpoints_at is not None else []
  self._should_write_summaries = should_write_summaries
  if not output_dir:
    self._should_save_checkpoints = False
    self._should_write_summaries = False
  self._checkpoint_highest = checkpoint_highest
  self._checkpoint_lowest = checkpoint_lowest
  self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
  # Inputs is either an Inputs instance or a function that returns it.
  self._inputs = inputs
  if callable(inputs):  # If we pass a function, e.g., through gin, call it.
    self._inputs = inputs()

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')
  self._model_with_loss = tl.Serial(model_train, loss_fn)

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, self._n_devices))
  shapes, dtypes = self._inputs.example_shape_dtype
  input_signature = tuple(ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

  def new_opt_state_and_model_state(rng):
    """Returns optimizer and model states suitable for training a model."""
    weights, state = self._model_with_loss.init(input_signature, rng=rng)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)

  if fastmath.is_backend(fastmath.Backend.JAX):
    # JIT parameter initialization to avoid memory fragmentation.
    new_opt_state_and_model_state = (
        fastmath.jit(new_opt_state_and_model_state))
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(init_rng))

  # Arrange and initialize metrics layers.
  self._metrics = list(sorted(self._metrics_dict.keys()))
  metrics_layers = [self._metrics_dict[m] for m in self._metrics]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  metrics_in_parallel.rng = init_rng
  example_signature = tuple(
      ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype))
  model_predict_eval.init(example_signature)
  self._input_signature = example_signature
  output_signature = model_predict_eval.output_signature(example_signature)
  m_weights, m_state = metrics_in_parallel.init(output_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(
      model_predict_eval, metrics_in_parallel, self._n_devices)
  self._jit_update_fn = _jit_update_fn(
      model_train, loss_fn, opt, self._n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  self._lr_schedule = lr_schedule

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._opt_state = None
  self._step = None
  self._model_state = None
  self.reset(output_dir)

def __init__(self, first_layer, reversible_layers, loss_layer,
             optimizer_fn, n_devices=None):
  """Creates a ReversibleSerialTrainer and the needed optimizers.

  This trainer performs updates equivalent to using the default Trainer on::

    tl.Serial([first_layer] + reversible_layers + [loss_layer]).

  It is more memory-efficient though since weights are stored on CPU and only
  sent to accelerator layer-by-layer. Note that the first layer and loss layer
  can be arbitrary layers, so they can be a `tl.Serial` combination of layers
  too. For now, we only support one block of reversible layers though.

  Args:
    first_layer: The first layer of the model; it can be arbitrary.
    reversible_layers: A list of reversible layers that are executed after
      the first layer. We do not keep their activations in memory and weights
      are moved to CPU RAM after each layer to free accelerator memory.
    loss_layer: The final layer of the model; it can have trainable weights
      but should end with a loss: it is required to produce a scalar output.
    optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
    n_devices: An optional integer, number of accelerator devices to use;
      by default, all available accelerators will be used.
  """
  # TODO(lukaszkaiser): remove these 2 lines once PR #4039 lands for JAX.
  if fastmath.is_backend(fastmath.Backend.JAX):
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access
  self._first_layer = first_layer
  self._reversible_layers = reversible_layers
  self._loss_layer = loss_layer
  self._optimizer_fn = optimizer_fn
  self._n_devices = n_devices or fastmath.device_count()

  # Create accelerated versions of layers as pmaped/jited pure_fn.
  self._accelerated_first_layer_fn = self._pjit(first_layer.pure_fn)

  self._accelerated_reversible_layers_fns = []
  for layer in reversible_layers:
    self._accelerated_reversible_layers_fns.append(
        self._pjit(layer.pure_fn))

  # Create per-layer optimizers and replicate opt_params.
  self._optimizers, self._replicated_opt_params = [], []
  for layer in [first_layer] + reversible_layers + [loss_layer]:
    optimizer = optimizer_fn()
    optimizer.tree_init(layer.weights)
    self._optimizers.append(optimizer)
    opt_params = self._replicate(optimizer.opt_params)
    self._replicated_opt_params.append(opt_params)

  # Forward + backward + optimizer-update functions for all layers.
  # We call them in short FBO for "Forward + Backward + Optimizer update".

  def first_fbo(inputs, weights, state, slots, opt_params, rng, step, grads):
    """FBO of the first layer."""
    # We need the first layer's pure_fn but only for inputs and weights.
    def first_layer_pure_fn_without_state_and_rng(x, w):
      return first_layer.pure_fn(x, w, state, rng)

    # Calculate vector-Jacobian product of the reduced first layer pure fn.
    activations_after_first_layer, vjp_fn, new_state = fastmath.vjp(
        first_layer_pure_fn_without_state_and_rng,
        inputs, weights, has_aux=True)
    del activations_after_first_layer  # unused

    # The vjp function returns gradients with respect to inputs and weights.
    _, grads_weights = vjp_fn(grads)

    # In multi-device setting, average gradients from multiple devices.
    if self._n_devices > 1:
      grads_weights = _average_multidevice_gradients(grads_weights)

    # Run the first layer optimizer, which is the first one.
    new_weights, new_slots, stats = self._optimizers[0].tree_update(
        step, grads_weights, weights, slots, opt_params)
    return new_weights, new_state, new_slots, stats

  # Accelerate the first layer FBO function and store it.
  self._first_fbo = self._pjit(first_fbo)

  # Loss layer FBO is like the first layer, but has no gradients argument
  # as it is the last layer and we always use 1.0 for that. On the other
  # hand, it adds the final activation (loss) into the returned stats.

  def loss_fbo(inputs, weights, state, slots, opt_params, rng, step):
    """FBO of the final loss layer."""
    # We need a loss layer pure_fn but only for inputs and weights.
    def loss_pure_fn_without_state_and_rng(x, w):
      return loss_layer.pure_fn(x, w, state, rng)

    # Calculate the vector-Jacobian product of the reduced loss pure fn.
    loss, vjp_fn, new_state = fastmath.vjp(
        loss_pure_fn_without_state_and_rng, inputs, weights, has_aux=True)

    # The vjp function returns gradients with respect to inputs and weights.
    # Since loss is scalar and there are no other layers, run it at 1.0.
    grads_inputs, grads_weights = vjp_fn(jnp.ones((), dtype=loss.dtype))

    # In multi-device setting, average gradients from multiple devices.
    if self._n_devices > 1:
      grads_weights = _average_multidevice_gradients(grads_weights)

    # Run the loss optimizer, which is the last one since it's the last layer.
    new_weights, new_slots, stats = self._optimizers[-1].tree_update(
        step, grads_weights, weights, slots, opt_params)
    stats['loss'] = loss
    return new_weights, new_state, new_slots, grads_inputs, stats

  # Accelerate the loss layer FBO function and store it.
  self._loss_fbo = self._pjit(loss_fbo)

  # Reversible layers define a reverse_and_fbo function that both reverses
  # and runs the forward-backward pass and applies the optimizer.
  # This function uses the `reverse_and_grad` method of reversible layers.

  def reverse_and_fbo_with_layer_and_opt(layer, optimizer):
    """Create the reverse_and_fbo function for a given layer and optimizer."""
    def reverse_and_fbo(output, weights, state, new_state,
                        slots, opt_params, rng, step, grads):
      """Reverse and FBO of the layer."""
      # Call the reverse_and_grad method of the layer.
      inputs, (grads_inputs, grads_weights) = layer.reverse_and_grad(
          output, grads, weights, state, new_state, rng=rng)

      # For non-trainable layers, return the calculated arguments.
      if not weights:
        return weights, slots, inputs, grads_inputs, {}

      # In multi-device setting, average gradients from multiple devices.
      if self._n_devices > 1:
        grads_weights = _average_multidevice_gradients(grads_weights)

      # Run the optimizer.
      new_weights, new_slots, stats = optimizer.tree_update(
          step, grads_weights, weights, slots, opt_params)

      return new_weights, new_slots, inputs, grads_inputs, stats
    return reverse_and_fbo

  # Accelerate the reverse_and_fbo functions and store them.
  self._reverse_and_fbos = []
  for layer, opt in zip(reversible_layers, self._optimizers[1:-1]):
    reverse_and_fbo = reverse_and_fbo_with_layer_and_opt(layer, opt)
    self._reverse_and_fbos.append(self._pjit(reverse_and_fbo))

def on_cpu(x):
  """Puts ``x`` in CPU memory in JAX."""
  if fastmath.is_backend(fastmath.Backend.JAX):
    return jax.device_put(x, jax.devices('cpu')[0])
  else:
    return x

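# A minimal usage sketch for `on_cpu` (illustrative only; `_on_cpu_example` is
# a hypothetical helper): under the JAX backend the array is placed on the
# host CPU device, freeing accelerator memory; under other backends the input
# is returned unchanged.
def _on_cpu_example():
  x = jnp.ones((1024, 1024))
  x_cpu = on_cpu(x)
  # Under JAX, x_cpu now lives on jax.devices('cpu')[0].
  return x_cpu
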