def forward(self, x):
    """Applies dropout to `x` when the layer runs in training mode.

    Outside of `'train'` mode the input is returned untouched. In training
    mode a Bernoulli keep-mask is sampled, with size 1 along every axis listed
    in `self._shared_axes`, and the kept activations are divided by the keep
    probability.

    Args:
      x: Tensor of activations.

    Returns:
      Tensor of same shape and dtype as the input.
    """
    if self._mode != 'train':
        return x

    layer_state, key = self.state, self.rng

    # A rate stored in the layer state under this layer's name takes
    # precedence over the rate the layer was created with.
    drop_rate = self._initial_rate
    if isinstance(layer_state, dict) and self._name in layer_state:
        drop_rate = layer_state[self._name]

    keep_shape = list(x.shape)
    for shared_axis in self._shared_axes:
        keep_shape[shared_axis] = 1
    keep_shape = tuple(keep_shape)

    on_jax = fastmath.backend_name() == 'jax'
    keep_prob = 1.0 - drop_rate
    if on_jax:
        keep_prob = jax.lax.tie_in(self.rng, keep_prob)
    keep = fastmath.random.bernoulli(key, keep_prob, keep_shape)
    if on_jax:
        keep_prob = jax.lax.tie_in(keep, keep_prob)
    return x * (keep.astype(x.dtype) / keep_prob)
def _init_host_and_devices(self, n_devices=None, random_seed=None):
    """Initializes host and device attributes for this trainer.

    Args:
      n_devices: Number of devices this trainer will use. If `None` (or any
          other falsy value), the number is taken from the backend.
      random_seed: Random seed as the starting point for all random numbers
          used by the trainer. If `None` and there is more than one host, a
          seed is derived from system time and host id.

    Returns:
      is_chief: True if this trainer has special chief responsibilities.
      n_devices: The passed in value of n_devices or a computed default.
      The result of `init_random_number_generators(random_seed)`.

    Raises:
      ValueError: On the JAX backend, if `n_devices` differs from the number
          of devices the backend reports.
    """
    on_jax = fastmath.backend_name() == 'jax'
    host_id, host_count = (
        (jax.host_id(), jax.host_count()) if on_jax else (0, 1))

    available = fastmath.device_count()
    if not n_devices:
        n_devices = available

    # TODO(lukaszkaiser): remove this restriction when possible.
    if n_devices != available and fastmath.backend_name() == 'jax':
        raise ValueError('JAX cannot work yet with n_devices != all devices: '
                         '%d != %d' % (n_devices, available))

    needs_generated_seed = random_seed is None and host_count > 1
    if needs_generated_seed:
        random_seed = int(1e6 * (host_id + time.time())) % 2**32

    rng_result = init_random_number_generators(random_seed)
    return host_id == 0, n_devices, rng_result
def policy(self, trajectory, temperature=1):
    """Chooses an action to play after a trajectory.

    Args:
      trajectory: Trajectory whose trailing `self._max_slice_length`
          timesteps are fed to the value model.
      temperature: Exploration is only considered when this equals 1.

    Returns:
      A pair `(sample, values)`: the chosen action and the values computed
      for the last observation.
    """
    recent = trajectory[-self._max_slice_length:]
    recent_np = recent.to_np(timestep_to_np=self.task.timestep_to_np)

    # Prepend a batch dimension and evaluate the value model.
    batched_obs = recent_np.observations[None, ...]
    all_values = self._run_value_model(batched_obs, use_eval_model=False)

    # Values and observations are expected to be laid out as
    # (batch, length, ...); only the leading (batch) dimension is checked.
    assert all_values.shape[:1] == batched_obs.shape[:1]

    # Single batch element, value for the last (current) observation.
    values = all_values[0, -1, :]

    # temperature == 0 is used in another place in order to trigger eval.
    explore = (
        np.random.random_sample() < self._exploration_rate(self._epoch)
        and temperature == 1)
    if explore:
        sample = np.array(self.task.action_space.sample())
    else:
        # Greedy choice.
        sample = jnp.argmax(values)

    outcome = (sample, values)
    if fastmath.backend_name() == 'jax':
        outcome = fastmath.nested_map(lambda leaf: leaf.copy(), outcome)
    return outcome
def mean_or_pmean(n_devices, x, axis=None):
    """jnp.mean or pmean.

    `x` is a distributed value. Directly calling jnp.mean on `x` means
    stacking x's components together to form a large array and then doing
    jnp.mean on it. In TF, stacking `x` will introduce D2H copy, so a
    collective is used there instead of calling jnp.mean directly.

    Args:
      n_devices: number of devices.
      x: a distributed array.
      axis: the axis to reduce. Can only be 0 or None on the collective path.

    Returns:
      A local array.

    Raises:
      ValueError: On the multi-device TF path, if `axis` is neither `None`
          nor 0.
    """
    use_collective = (
        fastmath.backend_name() == 'tensorflow-numpy' and n_devices > 1)
    if not use_collective:
        return jnp.mean(x, axis=axis)

    if axis not in (None, 0):
        raise ValueError('axis can only be None or 0')

    # Sum across devices, keep the first replica's copy, then normalize.
    reduced = fastmath.pmap(fastmath.psum)(x)[0] / n_devices
    if axis is None:
        reduced = jnp.mean(reduced)
    return reduced
def one_hot(x, n_categories, dtype=jnp.float32):  # pylint: disable=invalid-name
    """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).

    Args:
      x: Array of category indices.
      n_categories: Size of the new trailing axis.
      dtype: Dtype of the returned array.

    Returns:
      Array with one extra trailing axis of length `n_categories`, holding
      the result of comparing each entry of `x` with `0..n_categories-1`.
    """
    categories = jnp.arange(n_categories)
    if fastmath.backend_name() == 'jax':
        # Work around a jax broadcasting issue.
        categories = jax.lax.tie_in(x, categories)
    matches = x[..., jnp.newaxis] == categories
    return jnp.array(matches, dtype)
def _l2_norm(self, flat_list):
    """Returns the aggregate L2 norm of a list of tensors.

    Args:
      flat_list: Iterable of tensors.

    Returns:
      Square root of the sum of all per-tensor squared sums.
    """
    if fastmath.backend_name() == 'jax':
        squared_sum = lambda tensor: jnp.vdot(tensor, tensor)
    else:
        # TODO(lukaszkaiser): add vdot to TF-numpy
        squared_sum = lambda tensor: jnp.sum(tensor * tensor)
    total = sum(squared_sum(tensor) for tensor in flat_list)
    return jnp.sqrt(total)
def f(x):
    """Prepares `x` for `n_devices` devices.

    `n_devices` and `_multi_device_put` are taken from the surrounding scope.
    With at most one device `x` is returned unchanged; on the JAX backend it
    is handed to `_multi_device_put`; otherwise it is broadcast along a new
    leading axis of length `n_devices`.
    """
    if n_devices <= 1:
        return x
    if fastmath.backend_name() == 'jax':
        return _multi_device_put(x)
    replicated_shape = (n_devices, ) + x.shape
    return jnp.broadcast_to(x, replicated_shape)
def save_state(self, keep, prefix='model'):
    """Save trainer state given a possibly replicated opt_state.

    Args:
      keep: If truthy, additionally write a copy whose file name contains
          the current step.
      prefix: File name prefix for the written files.
    """
    def _first_replica(leaf):
        return leaf[0]

    opt_state = self._opt_state
    if self.n_devices > 1:
        opt_state = OptState(*fastmath.nested_map(_first_replica, opt_state))

    # This step, while optional, allows JAX to transfer arrays from the device
    # to the host in parallel, which is particularly important for cloud TPU.
    if fastmath.backend_name() == 'jax':
        opt_state = jax.device_get(opt_state)

    step = self._step
    output_dir = self._output_dir

    # This dict will be stored as the model.
    trainer_state_dict = make_trainer_state_dict(
        step, opt_state, self._history, self._model_state,
        self._input_signature)

    file_names = [prefix + '.pkl.gz']
    if keep:
        file_names.append('{}_{}.pkl.gz'.format(prefix, step))
    for file_name in file_names:
        self._save_state_dict(trainer_state_dict,
                              os.path.join(output_dir, file_name))
def forward(self, inputs):
    """Runs dot-product attention with a lower-triangular mask.

    Args:
      inputs: A `(queries, keys, values)` triple.

    Returns:
      The first output of `DotProductAttention`.
    """
    q, k, v = inputs
    if self._mode != 'predict':
        side = q.shape[-2]
        # Not all backends define jnp.tril. However, using np.tril is
        # inefficient in that it creates a large global constant.
        # TODO(kitaev): try to find an alternative that works across all
        # backends.
        xp = jnp if fastmath.backend_name() == 'jax' else np
        mask = xp.tril(xp.ones((1, side, side), dtype=np.bool_), k=0)
    else:
        # Fast inference: keys, values and mask come from the updated state.
        self.state = _fast_inference_update_state(inputs, self.state)
        k, v, mask, _ = self.state

    res, dots = DotProductAttention(
        q, k, v, mask,
        dropout=self._dropout, mode=self._mode, rng=self.rng)
    if self._mode == 'viz':
        # Keep the attention weights around for inspection.
        self.state = dots
    return res
def _jax_and_tf_configure_for_devices():  # pylint: disable=missing-function-docstring
    """Applies flag-driven device configuration to JAX and TensorFlow."""
    if FLAGS.use_tpu:
        tpu_options = (
            ('jax_platform_name', 'tpu'),
            ('jax_xla_backend', FLAGS.jax_xla_backend),
            ('jax_backend_target', FLAGS.jax_backend_target),
        )
        for option, value in tpu_options:
            jax.config.update(option, value)

    if not FLAGS.enable_eager_execution:
        return
    if fastmath.backend_name() in ('numpy', 'jax'):
        # Numpy backend doesn't benefit from having the input pipeline run on
        # GPU, and jax backend has GPU memory contention if TF uses the GPU.
        # Gin must be set up first before determining the backend.
        tf.config.experimental.set_visible_devices([], 'GPU')
def main(_):
    """Entry point: configures logging, TF, gin and devices, then trains.

    The single positional argument is ignored.
    """
    logging.set_verbosity(FLAGS.log_level)

    _tf_setup_from_flags()
    _gin_parse_configs()
    _jax_and_tf_configure_for_devices()

    output_dir = _output_dir_or_default()
    # The TensorFlow backend is referred to as 'tensorflow-numpy' elsewhere;
    # a bare 'tf' comparison alone would never match that name and would make
    # the TPU/TF branch unreachable. Both spellings are accepted.
    if FLAGS.use_tpu and fastmath.backend_name() in ('tf', 'tensorflow-numpy'):
        _train_using_tf(output_dir)
    else:
        trainer_lib.train(output_dir=output_dir)

    trainer_lib.log('Finished training.')
def policy(self, trajectory, temperature=1.0):
    """Chooses an action to play after a trajectory.

    Args:
      trajectory: Trajectory whose trailing `self._max_slice_length`
          timesteps are fed to the policy model.
      temperature: Sampling temperature; any value other than 1.0 selects
          the eval model instead of the collect model.

    Returns:
      A pair `(sample, pred)`: the sampled action and the model output for
      the last timestep.
    """
    if temperature == 1.0:
        model = self._policy_collect_model
    else:
        # When evaluating (t != 1.0), don't collect stats.
        model = self._policy_eval_model
        model.state = self._policy_collect_model.state
    model.replicate_weights(self._policy_trainer.model_weights)

    recent = trajectory[-self._max_slice_length:]
    recent_np = recent.to_np(timestep_to_np=self.task.timestep_to_np)

    # Prepend a batch dimension and run the model.
    batched_pred = model(recent_np.observations[None, ...])
    # Only batch element, last (current) timestep.
    pred = batched_pred[0, -1, :]

    sample = self._policy_dist.sample(pred, temperature=temperature)
    outcome = (sample, pred)
    if fastmath.backend_name() == 'jax':
        outcome = fastmath.nested_map(lambda leaf: leaf.copy(), outcome)
    return outcome
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

    This function is the core of the attention mechanism. It:

      - computes per-head attention weights from per-head `(queries, keys)`,
      - applies `mask` to screen out positions that come from padding tokens,
      - optionally applies dropout to attention weights, and
      - uses attention weights to combine per-head `values` vectors.

    Args:
      queries: Per-head activations representing attention queries.
      keys: Per-head activations representing attention keys.
      values: Per-head activations to be combined by computed attention
          weights.
      mask: Mask that distinguishes positions with real content vs. padding,
          or `None` to skip masking.
      dropout: Probabilistic rate for dropout applied to attention
          activations (based on query-key pairs) before dotting them with
          values. `None` or 0 disables dropout.
      mode: Either 'train' or 'eval'. Dropout applies only in 'train' mode.
      rng: Single-use random number generator (JAX PRNG key).

    Returns:
      A pair `(out, dots)`: the per-head activations resulting from the
      masked attention-weighted sum of per-head values, and the attention
      weights that were used to compute them.

    Raises:
      ValueError: If `dropout` is 1.0 or larger.
    """
    # Validate up front: `None` is an accepted "no dropout" value and must
    # not be compared with a float, and an invalid rate should be rejected
    # before any computation happens.
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')

    d_feature = queries.shape[-1]
    dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data
        # dependency on the input. Broadcasted copies of these use a lot of
        # memory, so they should be computed at runtime (rather than being
        # global constants).
        if fastmath.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Softmax.
    dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))

    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))

    out = jnp.matmul(dots, values)
    return out, dots
def forward(self, inputs):
    """Adds positional encodings taken from `self.weights` to `inputs`.

    Outside of 'predict' mode the first `inputs.shape[1]` positions of the
    weights are added, optionally multiplied by a dropout mask. In 'predict'
    mode the layer state holds the current position and is advanced on each
    call.

    Args:
      inputs: Tensor of activations.

    Returns:
      `inputs` with positional encodings added.

    Raises:
      ValueError: In 'predict' mode, if the dropout rate is not zero.
    """
    if self._mode == 'predict':
        if self._dropout != 0:
            raise ValueError(f'In predict mode, but dropout rate '
                             f'({self._dropout}) is not zero.')

        # State in this class is only used for fast inference. In that case,
        # the model is called with consecutive elements position-by-position.
        # This positional encoding layer needs to store the index of the
        # current position then and increment it on each call -- that's how
        # state is used and updated below.
        position = self.state
        n_new = inputs.shape[1]
        if n_new == 1:
            self.state = position + 1
            return inputs + jnp.expand_dims(self.weights[0, position, :], 1)

        per_example = [
            jax.lax.dynamic_slice_in_dim(
                self.weights[0], position[row], inputs.shape[1], axis=0)
            for row in range(inputs.shape[0])
        ]
        self.state = position + inputs.shape[1]
        return inputs + jnp.stack(per_example, 0)

    length = jnp.shape(inputs)[1]
    encodings = self.weights[:, :length, :]
    if self._dropout == 0:
        return inputs + encodings

    mask_shape = list(encodings.shape)
    for broadcast_dim in self._dropout_broadcast_dims:
        mask_shape[broadcast_dim] = 1

    keep_prob = 1.0 - self._dropout
    if fastmath.backend_name() == 'jax':
        keep_prob = jax.lax.tie_in(
            inputs, jnp.full((), keep_prob, dtype=inputs.dtype))
    keep = fastmath.random.bernoulli(self.rng, keep_prob, tuple(mask_shape))
    scale = keep.astype(inputs.dtype) / keep_prob
    return inputs + encodings * scale
def mean_or_pmean(n_devices, x, axis=None):
    """Computes the mean of a distributed value ``x``.

    On the 'tensorflow-numpy' backend with more than one device the mean is
    obtained through a cross-device sum; in every other case ``jnp.mean`` is
    applied directly.

    Args:
      n_devices: Number of devices.
      x: Distributed array.
      axis: Axis along which to compute means; can only be ``0`` or ``None``
          on the cross-device path.

    Returns:
      A local array.

    Raises:
      ValueError: On the cross-device path, if ``axis`` is neither ``None``
          nor ``0``.
    """
    if fastmath.backend_name() == 'tensorflow-numpy' and n_devices > 1:
        if axis in (None, 0):
            summed = fastmath.pmap(fastmath.psum)(x)
            device_mean = summed[0] / n_devices
            return jnp.mean(device_mean) if axis is None else device_mean
        raise ValueError('axis can only be None or 0')
    return jnp.mean(x, axis=axis)
def _fast_inference_update_state(inputs, state):
    """Updates state of a causal attention layer for fast inference.

    Each call handles exactly one new position per batch element: the new
    key and value are written into the cached keys/values at the per-example
    sequence index, the mask is set to 1 at that index, and the index is
    advanced by one.

    Args:
      inputs: A triple `(new_queries, new_keys, new_values)`; every element
          must have length 1 along axis 1.
      state: Layer state `(keys, values, mask, seq_indices)`.

    Returns:
      Updated state `(keys, values, mask, seq_indices + 1)`.

    Raises:
      ValueError: If the backend is not JAX, or if any input has a sequence
          length other than 1.
    """
    backend = fastmath.backend_name()
    if backend != 'jax':
        # The original message interpolated a misspelled attribute and failed
        # with AttributeError before this ValueError could be raised.
        raise ValueError(f'JAX backend is required in predict mode, but found '
                         f'backend ({backend}).')
    for x in inputs:
        if x.shape[1] != 1:
            raise ValueError(
                f'In predict mode, input sequence must have length 1, '
                f'instead has length {x.shape[1]}.')

    # Fast inference: run with only 1 query in each step, storing the sequence
    # of keys and values calculated so far in state.
    (_, new_k, new_v) = inputs
    (ks, vs, mask, seq_indices) = state
    batch_indices = jnp.arange(ks.shape[0])
    ks = jax.ops.index_update(
        ks, jax.ops.index[batch_indices, seq_indices, :], new_k[:, 0, :])
    vs = jax.ops.index_update(
        vs, jax.ops.index[batch_indices, seq_indices, :], new_v[:, 0, :])
    mask = jax.ops.index_update(
        mask, jax.ops.index[batch_indices, :, seq_indices], 1)
    return (ks, vs, mask, seq_indices + 1)
def forward(self, inputs):
    """Adds axial positional embeddings from `self.weights` to `inputs`.

    Every per-axis embedding is broadcast to
    `(batch,) + self._shape + (depth,)` and the results are joined along the
    last axis. In 'predict' mode a window starting at the position held in
    the layer state is added and the state is advanced; otherwise the full
    embedding is added, with an optional dropout mask.

    Args:
      inputs: Tensor of activations.

    Returns:
      `inputs` with positional embeddings added.
    """
    rng, state = self.rng, self.state
    batch = inputs.shape[0]
    embs = [
        jnp.broadcast_to(weight, (batch, ) + self._shape + (weight.shape[-1], ))
        for weight in self.weights
    ]

    if self._mode == 'predict':
        assert self._dropout == 0.0
        joined = jnp.concatenate(embs, -1)
        joined = jnp.reshape(joined, (inputs.shape[0], -1, joined.shape[-1]))
        window = jax.lax.dynamic_slice_in_dim(
            joined, state, inputs.shape[1], axis=1)
        self.state = state + inputs.shape[1]
        return inputs + window

    if self._dropout == 0:
        # TODO(kitaev): concat-then-reshape (as is the case with dropout
        # enabled) leads to memory blow-up on TPU, so each piece is reshaped
        # before concatenation here.
        target_prefix = inputs.shape[:-1]
        pieces = [
            jnp.reshape(piece, target_prefix + (piece.shape[-1], ))
            for piece in embs
        ]
        return inputs + jnp.concatenate(pieces, -1)

    joined = jnp.concatenate(embs, -1)
    mask_shape = list(joined.shape)
    for broadcast_dim in self._dropout_broadcast_dims:
        mask_shape[broadcast_dim] = 1

    keep_prob = 1.0 - self._dropout
    if fastmath.backend_name() == 'jax':
        keep_prob = jax.lax.tie_in(
            inputs, jnp.full((), keep_prob, dtype=inputs.dtype))
    keep = fastmath.random.bernoulli(rng, keep_prob, tuple(mask_shape))
    scale = keep.astype(inputs.dtype) / keep_prob
    return inputs + jnp.reshape(joined * scale, inputs.shape)
def _fast_inference_update_state(inputs, state):
    """Updates state of a causal attention layer for fast inference.

    The layer state stores tensors with cached values of keys and values,
    as well as the mask and an index. To make shapes static, keys and values
    in the state are long, and the index indicates where the new keys and
    values from inputs need to be appended. Mask ensures that attention will
    only look at keys up to index.

    During update, we append new_keys and new_values to keys and values at
    position given by index. We also update mask (which starts as all-0s) to
    be 1 at the new keys positions. And we increment index by length of new
    keys.

    Args:
      inputs: a triple (new_queries, new_keys, new_values)
      state: layer state with (keys, values, mask, index)

    Returns:
      Updated state.

    Raises:
      ValueError: If the backend is not JAX.
    """
    backend = fastmath.backend_name()
    if backend != 'jax':
        # The original message interpolated a misspelled attribute and failed
        # with AttributeError before this ValueError could be raised.
        raise ValueError(f'JAX backend is required in predict mode, but found '
                         f'backend ({backend}).')

    # Fast inference: run step-by-step, storing the sequence
    # of keys and values calculated so far in state.
    (_, new_k, new_v) = inputs
    length = new_k.shape[1]
    (ks, vs, mask, idx) = state
    # TODO(lukaszkaiser): benchmark speed and decide if using a separate code
    # path with index_update when length == 1 is worth it.
    # Keys and values are of shape [batch_size, length, d_kv].
    ks = jax.lax.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)
    vs = jax.lax.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)
    # Mask is of shape [batch_size, 1 (for heads), length].
    new_mask = jnp.ones((mask.shape[0], mask.shape[1], length))
    mask = jax.lax.dynamic_update_slice_in_dim(mask, new_mask, idx, axis=2)
    return (ks, vs, mask, idx + length)
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.CrossEntropyLoss(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule=lr.multifactor(),
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          custom_train_fn=None):
    """Train the model on the inputs.

    Args:
      output_dir: Directory where to put the logs and checkpoints.
      model: The model to train; handed to `trainer_class` unchanged.
      loss_fn: Loss to optimize; handed to `trainer_class` unchanged.
      inputs: callable returning trax.inputs.Inputs.
      optimizer: The optimizer (see optimizers/base.py for signature).
      lr_schedule: A learning rate schedule as a function that takes history
          and returns a function from step to learning rate (a float).
      trainer_class: The trainer class to use.
      steps: int, total number of training steps.
      checkpoints_at: list of integers. Save a checkpoint for each training
          step in the list.
      eval_steps: int, num of steps per evaluation. If None or 0, eval
          disabled.
      eval_frequency: int, how often to run evaluation (every eval_frequency
          steps). If None or 0, eval disabled.
      random_seed: the random seed to use; time/os dependent if None
          (default).
      save_graphs: bool, if True, save computation graph to file (JAX
          backend only).
      metrics: optionally override the default metrics dictionary.
      checkpoint_highest: save the checkpoint highest at this metric.
      checkpoint_lowest: save the checkpoint lowest at this metric.
      custom_train_fn: custom train function to call, entirely bypassing
          this one; it is called as `custom_train_fn(output_dir, model=model)`
          and its result is returned.

    Returns:
      trax.TrainerState
    """
    if custom_train_fn is not None:
        return custom_train_fn(output_dir, model=model)

    n_devices = num_devices()
    trainer = trainer_class(model, loss_fn, optimizer, lr_schedule, inputs,
                            output_dir,
                            random_seed=random_seed,
                            n_devices=n_devices,
                            checkpoints_at=checkpoints_at,
                            metrics=metrics,
                            checkpoint_lowest=checkpoint_lowest,
                            checkpoint_highest=checkpoint_highest)

    # Without evaluation there is a single epoch covering all steps.
    epoch_lengths = [steps]
    # `eval_steps` may be None (eval disabled); test truthiness first so the
    # numeric comparison is never applied to None.
    if eval_frequency and eval_steps and eval_steps > 0:
        epoch_lengths = itertools.chain(
            [1,  # first epoch only 1 step
             eval_frequency - 1],
            itertools.repeat(eval_frequency))

    trainer.log_step('Starting training using %d devices' % trainer.n_devices)
    trainer.print_n_weights()

    try:
        for n_epoch_steps in epochs(steps, trainer.step, epoch_lengths):
            trainer.train_epoch(n_epoch_steps, eval_steps)

            # Bookkeeping we do at the first step
            if trainer.step == 1:
                # Save computation graph (single-device only for now)
                if save_graphs and fastmath.backend_name() == 'jax':
                    trainer.save_computation_graphs()

                # Save Gin config
                trainer.save_gin()

        trainer.log_step('Training done')
    finally:
        # Exceptions propagate unchanged; the trainer is always closed.
        trainer.close()
    return trainer.state
def test_use_backend_str(self):
    """`use_backend` accepts a backend name given as a string."""
    requested = 'tensorflow-numpy'
    with fastmath.use_backend(requested):
        active = fastmath.backend_name()
        self.assertEqual(active, 'tensorflow-numpy')
def test_use_backend_enum(self):
    """`use_backend` accepts a `fastmath.Backend` enum member."""
    requested = fastmath.Backend.NUMPY
    with fastmath.use_backend(requested):
        active = fastmath.backend_name()
        self.assertEqual(active, 'numpy')
def ReformerNoEncDecAttention(input_vocab_size, output_vocab_size=None, d_model=512, d_ff=2048, d_attention_key=64, d_attention_value=64, n_encoder_layers=6, n_decoder_layers=6, n_heads=8, dropout=0.1, max_len=2048, encoder_attention_type=tl.SelfAttention, encoder_decoder_attention_type=tl.SelfAttention, axial_pos_shape=(), d_axial_pos_embs=None, ff_activation=tl.Relu, ff_use_sru=0, ff_chunk_size=0, ff_dropout=None, mode='train'): """Reversible transformer encoder-decoder model. This model expects an input pair: source, target. At the moment, this model supports dot-product attention only. For the attention types in the Reformer paper, see ReformerLM. Args: input_vocab_size: int: vocab size of the source. output_vocab_size: int (optional): vocab size of the target. If None, the source and target are assumed to have the same vocab. d_model: int: depth of embedding d_ff: int: depth of feed-forward layer d_attention_key: int: depth of key vector for each attention head d_attention_value: int: depth of value vector for each attention head n_encoder_layers: int: number of encoder layers n_decoder_layers: int: number of decoder layers n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) max_len: int: maximum symbol length for positional encoding encoder_attention_type: class: attention class to use, such as SelfAttention encoder_decoder_attention_type: class: attention class to use, such as SelfAttention axial_pos_shape: tuple of ints: input shape to use for the axial position encoding. If unset, axial position encoding is disabled. d_axial_pos_embs: tuple of ints: depth of position embedding for each axis. Tuple length must match axial_pos_shape, and values must sum to d_model. 
ff_activation: the non-linearity in feed-forward layer ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks ff_dropout: float: (optional) separate dropout rate at feed-forward nonlinearity. This is called relu_dropout in T2T. mode: str: 'train' or 'eval' Returns: A Reformer model as a layer that maps from a target, source pair to activations over a vocab set. """ # The current API for custom gradients assumes that a layer must be # differentiable wrt all of its inputs, but the Transformer puts bool-dtype # masks on the stack. This causes jax to error, even though the so-called # "gradient" wrt the masks is never actually computed. # TODO(kitaev): remove this hack. if fastmath.backend_name() == 'jax': jax.api._check_inexact_input_vjp = lambda x: None # pylint: disable=protected-access def PositionalEncoder(vocab_size, mode): # tokens --> vectors if not axial_pos_shape: positional_encoding = tl.PositionalEncoding( max_len=max_len, dropout=dropout, mode=mode) else: assert d_axial_pos_embs is not None positional_encoding = tl.AxialPositionalEncoding( shape=axial_pos_shape, d_embs=d_axial_pos_embs, dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)), dropout=dropout, mode=mode) return [ tl.Embedding(vocab_size, d_model), tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode), positional_encoding, ] # TODO(kitaev): The regular trax Transformer shares vocab embeddings and # position embeddings between the encoder and decoder if output_vocab_size is # None. This isn't supported here because (a) Trax shares weights by sharing # layer instances, but we need two separate instances to have mode == 'eval' # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does # not work if its sublayers participate in any weight sharing. # Mode 'predict' means that the decoder should be run one token at a time. 
# The encoder only ever runs over full sequences, which is why it's switched # to 'eval' mode instead. in_encoder = PositionalEncoder( input_vocab_size, mode='eval' if mode == 'predict' else mode) if output_vocab_size is None: output_vocab_size = input_vocab_size out_encoder = PositionalEncoder(output_vocab_size, mode) # pylint: disable=g-complex-comprehension encoder_blocks = [ EncoderBlock( d_model, d_ff, n_heads, encoder_attention_type, dropout, ff_activation, ff_dropout, mode) for _ in range(n_encoder_layers)] # pylint: enable=g-complex-comprehension encoder = tl.Serial([ # tok_e mask_e tok_e tok_d tok_d in_encoder, # vec_e mask_e tok_e tok_d tok_d tl.Dup(), # vec_e1 vec_e2 mask_e tok_e tok_d tok_d tl.ReversibleSerial(encoder_blocks), tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0), tl.LayerNorm(), ]) if mode == 'predict': encoder = tl.Cache(encoder) decoder_blocks = [] if isinstance(encoder_decoder_attention_type, (tuple, list)): assert n_decoder_layers % len(encoder_decoder_attention_type) == 0 else: encoder_decoder_attention_type = [encoder_decoder_attention_type] for layer_idx in range(n_decoder_layers): layer_attention_type = encoder_decoder_attention_type[ layer_idx % len(encoder_decoder_attention_type)] decoder_block = DecoderBlock( d_model, d_ff, d_attention_key, d_attention_value, n_heads, attention_type=layer_attention_type, dropout=dropout, ff_activation=ff_activation, ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size, mode=mode) decoder_blocks.append(decoder_block) # Assemble and return the model. return tl.Serial( # Input: encoder_side_tokens, decoder_side_tokens # Copy decoder tokens for use in loss. tl.Select([0, 0, 1, 1]), # tok_e tok_e tok_d tok_d tl.Branch([], [tl.PaddingMask(), tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]), # # tok_e mask_e tok_e tok_d tok_d # Encode. encoder, # vec_e mask_e tok_e tok_d tok_d # Decode. 
tl.Select([3, 0, 1, 2]), # tok_d vec_e mask_e tok_e tok_d tl.ShiftRight(mode=mode), # stok_d vec_e mask_e tok_e tok_d tl.Branch( [], _MaskOfRightShiftedArray() ), # stok_d mask_d vec_e mask_e tok_e tok_d out_encoder, # svec_d mask_d vec_e mask_e tok_e tok_d # Concat encoder and decoder, given their masks. tl.Select([2, 0, 3, 1]), # svec_d mask_d vec_e mask_e tok_e tok_d _ConcatWithPadding(), # vec_ed tok_e tok_d # Run (encoder and) decoder blocks. tl.Dup(), # vec_ed1 vec_ed2 tok_e tok_d tl.ReversibleSerial(decoder_blocks), # vec_ed1 vec_ed2 tok_e tok_d tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0), # vec_ed tok_e tok_d tl.LayerNorm(), # vec_ed tok_e tok_d # Separate out the encoder part from the concatenated vector. tl.Select([0, 1, 2, 2]), # vec_ed tok_e tok_d tok_d _StripFromConcatenateWithPadding(), # vec_d tok_d # Map to output vocab. tl.Dense(output_vocab_size), # vec_d tok_d tl.LogSoftmax(), # vec_d tok_d )
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
    """Reversible transformer encoder-decoder model.

    This model expects an input pair: source, target (encoder-side tokens
    first, decoder-side tokens second).

    At the moment, this model supports dot-product attention only. For the
    attention types in the Reformer paper, see ReformerLM.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None,
          the source and target are assumed to have the same vocab.
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_encoder_layers: int: number of encoder layers
      n_decoder_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      ff_activation: the non-linearity in feed-forward layer
      ff_dropout: float: (optional) separate dropout rate at feed-forward
          nonlinearity. This is called relu_dropout in T2T.
      mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder
          is built in 'eval' mode and wrapped in `tl.Cache`.

    Returns:
      A Reformer model as a layer that maps from a source, target pair to
      activations over a vocab set.
    """
    # The current API for custom gradients assumes that a layer must be
    # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
    # masks on the stack. This causes jax to error, even though the so-called
    # "gradient" wrt the masks is never actually computed.
    # TODO(kitaev): remove this hack.
    # NOTE: this replaces a private JAX function for the whole process, not
    # just for this model.
    if fastmath.backend_name() == 'jax':
        jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

    def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
        # TODO(kitaev): axial positional encoding is better for very long
        # sequences.
        positional_encoding = tl.PositionalEncoding(
            max_len=max_len, dropout=dropout, mode=mode)
        return [
            tl.Embedding(vocab_size, d_model),
            tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
            positional_encoding,
        ]

    # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
    # position embeddings between the encoder and decoder if output_vocab_size is
    # None. This isn't supported here because (a) Trax shares weights by sharing
    # layer instances, but we need two separate instances to have mode == 'eval'
    # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
    # not work if its sublayers participate in any weight sharing.

    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's switched
    # to 'eval' mode instead.
    in_encoder = PositionalEncoder(
        input_vocab_size, mode='eval' if mode == 'predict' else mode)
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
    out_encoder = PositionalEncoder(output_vocab_size, mode)

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(
            d_model, d_ff, n_heads, tl.SelfAttention, dropout, ff_activation,
            ff_dropout, mode)
        for _ in range(n_encoder_layers)]
    # pylint: enable=g-complex-comprehension

    encoder = tl.Serial([
        in_encoder,
        tl.Dup(),
        tl.ReversibleSerial(encoder_blocks),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.LayerNorm(),
    ])
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    encoder_decoder_blocks = [
        EncoderDecoderBlock(
            d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)
        for _ in range(n_decoder_layers)]

    # Assemble and return the model. The trailing comments track the data
    # stack after each layer.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),                 # tok_e tok_d tok_d
        tl.Branch([], [tl.PaddingMask(),
                       tl.Fn('Squeeze',
                             lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
        #                                     # tok_e mask  tok_d .....

        # Encode.
        encoder,                              # vec_e  mask tok_d .....

        # Decode.
        tl.Select([2, 0, 1]),                 # tok_d vec_e mask .....
        tl.ShiftRight(mode=mode),             # tok_d vec_e mask .....
        out_encoder,                          # vec_d vec_e mask .....
        tl.Dup(),                             # vec_d1 vec_d2 vec_e mask .....
        tl.ReversibleSerial(encoder_decoder_blocks),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
        tl.LayerNorm(),                       # vec_d vec_e mask .....

        # Map to output vocab.
        tl.Select([0], n_in=3),               # vec_d .....
        tl.Dense(output_vocab_size),          # vec_d .....
        tl.LogSoftmax(),                      # vec_d .....
    )
def test_backend_can_be_set(self):
    """`set_backend` switches the backend and `None` restores 'jax'."""
    self.assertEqual(fastmath.backend_name(), 'jax')
    transitions = (
        ('tensorflow-numpy', 'tensorflow-numpy'),
        (None, 'jax'),
    )
    for requested, expected in transitions:
        fastmath.set_backend(requested)
        self.assertEqual(fastmath.backend_name(), expected)
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True,
             metrics=None, checkpoint_highest=None,
             checkpoint_lowest=None):
    """Builds the trainer: models, optimizer, RNGs, metrics and jitted fns.

    Args:
      model: Callable invoked as `model(mode='train')` and
          `model(mode='eval')` to create the training and evaluation models.
      loss_fn: Loss layer; combined with the training model via `tl.Serial`.
      optimizer: Callable invoked as `optimizer(learning_rate=0.0)`.
      lr_schedule: Learning rate schedule; stored for later use.
      inputs: Either an inputs object or a callable returning one.
      output_dir: Output directory; when falsy, checkpointing and summary
          writing are both switched off.
      random_seed: Seed forwarded to `training.init_host_and_devices`.
      n_devices: Device count forwarded to `training.init_host_and_devices`.
      checkpoints_at: Optional collection of steps; defaults to an empty
          list.
      should_save_checkpoints: Whether to save checkpoints; only honored on
          the chief host.
      should_write_summaries: Whether to write summaries.
      metrics: Optional metrics dictionary; `_DEFAULT_METRICS` is used when
          None.
      checkpoint_highest: Stored as given.
      checkpoint_lowest: Stored as given.
    """
    self._is_chief, _, self._n_devices, rng = (
        training.init_host_and_devices(n_devices, random_seed))
    self._should_save_checkpoints = should_save_checkpoints and self._is_chief
    self._checkpoints_at = checkpoints_at or []
    self._should_write_summaries = should_write_summaries
    # Without an output directory there is nowhere to write anything.
    if not output_dir:
        self._should_save_checkpoints = False
        self._should_write_summaries = False
    self._checkpoint_highest = checkpoint_highest
    self._checkpoint_lowest = checkpoint_lowest
    self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    # Inputs is either an Inputs instance or a function that returns it.
    self._inputs = inputs
    if callable(
        inputs):  # If we pass a function, e.g., through gin, call it.
        self._inputs = inputs()
    # Initialize the learning rate to a dummy value. It will be set in reset().
    opt = optimizer(learning_rate=0.0)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')
    self._model_with_loss = tl.Serial(model_train, loss_fn)

    # Setup state.
    # `init_rng` is reserved for initialization; the remaining key is split
    # into one key per device.
    rng, init_rng = jax_random.split(rng)
    self._rngs = np.stack(jax_random.split(rng, self._n_devices))
    shapes, dtypes = self._inputs.example_shape_dtype
    input_signature = tuple(
        ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

    def new_opt_state_and_model_state(rng):
        """Returns optimizer and model states suitable for training a model."""
        weights, state = self._model_with_loss.init(input_signature, rng=rng)
        (slots, opt_params) = opt.tree_init(weights)
        return (OptState(weights, slots, opt_params), state)

    if fastmath.backend_name() == 'jax':
        # JIT parameter initialization to avoid memory fragmentation
        new_opt_state_and_model_state = (
            fastmath.jit(new_opt_state_and_model_state))
    # Zero-argument factory that always initializes from the same `init_rng`.
    self._new_opt_state_and_model_state = (
        lambda: new_opt_state_and_model_state(init_rng))

    # Arrange and initialize metrics layers.
    self._metrics = list(sorted(self._metrics_dict.keys()))
    metrics_layers = [self._metrics_dict[m] for m in self._metrics]
    metrics_in_parallel = tl.Branch(*metrics_layers)
    metrics_in_parallel.rng = init_rng
    # Built from the same example shapes/dtypes as `input_signature` above.
    example_signature = tuple(
        ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype))
    model_predict_eval.init(example_signature)
    self._input_signature = example_signature
    output_signature = model_predict_eval.output_signature(
        example_signature)
    m_weights, m_state = metrics_in_parallel.init(output_signature)
    self._metrics_weights = self._for_n_devices(m_weights)
    self._metrics_state = self._for_n_devices(m_state)

    # Jit model_predict and update so they're fast.
    self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_in_parallel,
                                     self._n_devices)
    self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                         self._n_devices)

    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    self._lr_schedule = lr_schedule

    # Those fields will be set in reset().
    self._output_dir = None
    self._train_sw = None
    self._eval_sw = None
    self._history = None
    self._opt_state = None
    self._step = None
    self._model_state = None
    self.reset(output_dir)
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.CrossEntropyLoss(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=False):
  """Train the model on the inputs.

  Two code paths are supported. With `use_loop=True` a `training.Loop` is built
  from one train task and one eval task and run for `steps` steps. Otherwise a
  `trainer_class` instance is created and driven epoch by epoch, and it is
  always closed before this function returns or raises.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: Callable that builds the model. With `use_loop` it is called as
      `model(mode='train')` and `model(mode='eval')`; otherwise it is handed
      unchanged to `trainer_class`.
    loss_fn: The loss. With `use_loop` it is passed as `loss_layer` to
      `training.TrainTask`; otherwise it is handed to `trainer_class`.
    inputs: trax.inputs.Inputs, or a callable returning one. With `use_loop`
      a callable is invoked here; otherwise `inputs` is handed to
      `trainer_class` as is.
    optimizer: The optimizer (see optimizers/base.py for signature). With
      `use_loop` it is called with no arguments to create the optimizer.
    lr_schedule_fn: A learning rate schedule function, that when called
      returns a function from step to learning rate (a float).
    trainer_class: The trainer class to use (ignored when `use_loop` is True).
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled. With `use_loop` it is also used as
      `n_steps_per_checkpoint` of the train task.
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file (only done on
      the 'jax' backend, and only on the `trainer_class` path).
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint highest at this metric.
    checkpoint_lowest: save the checkpoint lowest at this metric.
    use_loop: whether to use training.Loop instead of Trainer.

  Returns:
    trax.TrainerState or training.Loop if use_loop is True
  """
  if use_loop:
    n_devices = num_devices() or fastmath.device_count()

    # Prepare the training task.
    # Inputs is either an Inputs instance or a function that returns it.
    if callable(inputs):  # If we pass a function, e.g., through gin, call it.
      inputs = inputs()
    train_task = training.TrainTask(
        inputs.train_stream(n_devices),
        loss_layer=loss_fn,
        optimizer=optimizer(),
        lr_schedule=lr_schedule_fn(),
        n_steps_per_checkpoint=eval_frequency)

    # Prepare the evaluation. Names and layers are kept in separate locals so
    # the `metrics` argument is not rebound.
    metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    metric_names, metric_layers = zip(*metrics_dict.items())
    eval_task = training.EvalTask(
        inputs.eval_stream(n_devices),
        metric_layers,
        metric_names=metric_names,
        n_eval_batches=eval_steps)

    # Prepare the training loop.
    checkpoint_at = None
    if checkpoints_at is not None:
      checkpoint_at = lambda step: step in checkpoints_at
    loop = training.Loop(
        model(mode='train'), [train_task],
        eval_model=model(mode='eval'),
        eval_tasks=[eval_task],
        output_dir=output_dir,
        checkpoint_at=checkpoint_at,
        n_devices=n_devices,
        random_seed=random_seed)

    # Train and return the loop.
    loop.run(steps)
    return loop

  n_devices = num_devices()
  trainer = trainer_class(
      model, loss_fn, optimizer, lr_schedule_fn(), inputs, output_dir,
      random_seed=random_seed,
      n_devices=n_devices,
      checkpoints_at=checkpoints_at,
      metrics=metrics,
      checkpoint_lowest=checkpoint_lowest,
      checkpoint_highest=checkpoint_highest)

  # By default run everything as a single epoch (training only). The
  # truthiness check on eval_steps comes first so that the documented
  # `eval_steps=None` does not hit `None > 0`, which raises TypeError.
  epoch_steps = [steps]
  if eval_frequency and eval_steps and eval_steps > 0:
    epoch_steps = itertools.chain(
        [1,  # first epoch only 1 step
         eval_frequency - 1],
        itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  # try/finally guarantees the trainer is closed; any exception propagates
  # to the caller unchanged.
  try:
    for n_epoch_steps in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(n_epoch_steps, eval_steps)

      # Bookkeeping we do at the first step
      if trainer.step == 1:
        # Save computation graph (single-device only for now)
        if save_graphs and fastmath.backend_name() == 'jax':
          trainer.save_computation_graphs()

        # Save Gin config
        trainer.save_gin()

    trainer.log_step('Training done')
  finally:
    trainer.close()
  return trainer.state