def _init_host_and_devices(self, n_devices=None, random_seed=None):
  """Initializes host and device attributes for this trainer.

  Args:
    n_devices: Number of devices this trainer will use. If `None`, get the
        number from the backend.
    random_seed: Random seed as the starting point for all random numbers used
        by the trainer. If `None`, calculate one from system time and host id.

  Returns:
    is_chief: True if this trainer has special chief responsibilities.
    n_devices: The passed in value of n_devices or a computed default.
    rng: Random number generator initialized from the passed-in random seed
        or from a computed default.
  """
  if math.backend_name() == 'jax':
    host_id = jax.host_id()
    host_count = jax.host_count()
  else:
    host_id = 0
    host_count = 1
  is_chief = (host_id == 0)

  device_count = math.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count and math.backend_name() == 'jax':
    raise ValueError('JAX cannot work yet with n_devices != all devices: '
                     '%d != %d' % (n_devices, device_count))

  if random_seed is None and host_count > 1:
    random_seed = int(1e6 * (host_id + time.time())) % 2**32
  return is_chief, n_devices, init_random_number_generators(random_seed)
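
# --- Illustrative sketch (not part of the library) ---
# How the multi-host default seed above behaves: each host mixes its id with
# wall-clock time, so hosts get distinct seeds within a run. `host_id` and
# `host_count` values below are made up for illustration.
import time

def _example_default_seed(host_id, host_count):
  if host_count > 1:
    return int(1e6 * (host_id + time.time())) % 2**32
  return None  # Single host: seeding is left to init_random_number_generators.

# e.g. _example_default_seed(host_id=1, host_count=8) -> a value in [0, 2**32)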
def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  Args:
    x: Tensor of activations.

  Returns:
    Tensor of same shape and dtype as the input.
  """
  if self._mode != 'train':
    return x

  state, rng = self.state, self.rng
  rate = self._initial_rate
  if isinstance(state, dict) and self._name in state:
    rate = state[self._name]
  mask_shape = list(x.shape)
  for axis in self._shared_axes:
    mask_shape[axis] = 1
  if math.backend_name() == 'jax':
    keep_prob = jax.lax.tie_in(self.rng, 1.0 - rate)
  else:
    keep_prob = 1.0 - rate
  keep = math.random.bernoulli(rng, keep_prob, tuple(mask_shape))
  if math.backend_name() == 'jax':
    keep_prob = jax.lax.tie_in(keep, keep_prob)
  mask = keep.astype(x.dtype) / keep_prob
  return x * mask
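
# --- Illustrative NumPy sketch (not part of the library) ---
# The mask-sharing logic above, assuming `_shared_axes = [1]`: one Bernoulli
# draw is broadcast along the length axis, and kept values are scaled by
# 1 / keep_prob (inverted dropout) so activations keep their expectation.
import numpy as onp

demo_rng = onp.random.default_rng(0)
demo_x = demo_rng.normal(size=(2, 5, 4))      # (batch, length, features)
demo_rate = 0.5
demo_mask_shape = [2, 1, 4]                   # axis 1 shares one mask
demo_keep = demo_rng.random(demo_mask_shape) < (1.0 - demo_rate)
demo_y = demo_x * demo_keep / (1.0 - demo_rate)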
def forward_with_state(self, inputs, weights, state, rng):
  if self._mode != 'predict':
    x = inputs
    symbol_size = jnp.shape(x)[1]
    px = weights[:, :symbol_size, :]
    if self._dropout == 0:
      return (x + px, state)
    else:
      noise_shape = list(px.shape)
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      if math.backend_name() == 'jax':
        keep_prob = jax.lax.tie_in(x, jnp.full((), keep_prob, dtype=x.dtype))
      keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / keep_prob
      return (x + px * multiplier, state)
  else:
    if self._dropout != 0:
      raise ValueError(f'In predict mode, but dropout rate '
                       f'({self._dropout}) is not zero.')

    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer needs to store the index of the current
    # position then and increment it on each call -- that's how state is used
    # and updated below.
    if inputs.shape[1] == 1:
      return (inputs + jnp.expand_dims(weights[0, state, :], 1), state + 1)
    else:
      emb = []
      for i in range(inputs.shape[0]):
        emb.append(jax.lax.dynamic_slice_in_dim(
            weights[0], state[i], inputs.shape[1], axis=0))
      return inputs + jnp.stack(emb, 0), state + inputs.shape[1]
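
# --- Illustrative NumPy analogue (not part of the library) ---
# The predict-mode slicing above, per batch element: each example reads
# `seq_len` rows of the positional-embedding table starting at its own
# position counter, mirroring jax.lax.dynamic_slice_in_dim. All names below
# are made up for illustration.
import numpy as onp

table = onp.random.randn(16, 4)        # (max_len, d_model) embedding rows
positions = onp.array([3, 5])          # per-example position counters
seq_len = 2
emb_demo = onp.stack([table[p:p + seq_len] for p in positions], 0)
# emb_demo.shape == (2, 2, 4); positions would then advance by seq_len.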
def save_state(self, keep, prefix='model'):
  """Save trainer state given a possibly replicated opt_state."""
  opt_state = self._opt_state
  if self.n_devices > 1:
    first_replica = lambda x: x[0]
    opt_state = OptState(*math.nested_map(first_replica, opt_state))
  # This line, while optional, allows JAX to transfer arrays from the device
  # to the host in parallel, which is particularly important for cloud TPU.
  if math.backend_name() == 'jax':
    opt_state = jax.device_get(opt_state)
  step, history, model_state = self._step, self._history, self._model_state
  output_dir = self._output_dir

  weights_file = os.path.join(output_dir, prefix + '.pkl.gz')

  # This dict will be stored as the model.
  trainer_state_dict = make_trainer_state_dict(
      step, opt_state, history, model_state, self._input_signature)
  self._save_state_dict(trainer_state_dict, weights_file)

  if keep:
    weights_file = os.path.join(output_dir,
                                '{}_{}.pkl.gz'.format(prefix, step))
    self._save_state_dict(trainer_state_dict, weights_file)
def forward_and_backward(self, inputs, ct, state, new_state, rng=None,
                         **kwargs):
  assert math.backend_name() == 'jax', (
      'JAX backend is required to use forward_and_backward.')

  if ct is not None and new_state is not tl.EMPTY_STATE:
    recovered_rng = new_state
    is_same = (rng[0] == recovered_rng[0]) & (rng[1] == recovered_rng[1])
    is_same = is_same.astype(np.float32)
    # Dividing by zero when the rngs don't match poisons the inputs with
    # Inf/NaN values, so a mismatch cannot go unnoticed downstream.
    inputs = (inputs[0] / is_same, inputs[1] / is_same, inputs[2] / is_same)

  def _do_forward(x):  # pylint: disable=invalid-name
    res, _ = self.forward_with_state(x, state=state, rng=rng, **kwargs)
    return res

  output, vjpfun = jax.vjp(_do_forward, inputs)
  return output, vjpfun(ct)[0]
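
# --- Illustrative NumPy sketch (not part of the library) ---
# The rng-consistency check above: dividing by a 0.0 indicator floods the
# values with Inf/NaN. The rng values below are made up for illustration.
import numpy as onp

rng_demo = onp.array([1, 2], dtype=onp.uint32)
recovered_demo = onp.array([1, 3], dtype=onp.uint32)
is_same = onp.float32((rng_demo[0] == recovered_demo[0]) &
                      (rng_demo[1] == recovered_demo[1]))
with onp.errstate(divide='ignore', invalid='ignore'):
  poisoned = onp.array([1.0, 0.0]) / is_same   # [inf, nan] when rngs differ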
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data
    # dependency on the input. Broadcasted copies of these use a lot of
    # memory, so they should be computed at runtime (rather than being global
    # constants).
    if math.backend_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
  if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
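
# --- Illustrative NumPy sketch (not part of the library) ---
# Masked dot-product attention on tiny arrays, mirroring the math above but
# without the jax.lax.tie_in workaround and without dropout. The causal mask
# is an assumption for the demo; the softmax uses the equivalent
# max-subtraction form.
import numpy as onp

q = onp.random.randn(1, 3, 4)
k = onp.random.randn(1, 3, 4)
v = onp.random.randn(1, 3, 4)
dots = onp.matmul(q, onp.swapaxes(k, -1, -2)) / onp.sqrt(4)
causal_mask = onp.tril(onp.ones((1, 3, 3), dtype=onp.bool_))
dots = onp.where(causal_mask, dots, onp.full_like(dots, -1e9))
weights = onp.exp(dots - onp.max(dots, axis=-1, keepdims=True))
weights /= onp.sum(weights, axis=-1, keepdims=True)
out = onp.matmul(weights, v)   # same shape as v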
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  del weights
  q, k, v = inputs
  if self._mode in ('train', 'eval'):
    mask_size = q.shape[-2]
    # Not all backends define np.tril. However, using onp.tril is inefficient
    # in that it creates a large global constant.
    # TODO(kitaev): try to find an alternative that works across all backends.
    if math.backend_name() == 'jax':
      mask = np.tril(
          np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
    else:
      mask = onp.tril(
          onp.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
  else:
    assert self._mode == 'predict'
    state = _fast_inference_update_state(inputs, state)
    (k, v, mask, _) = state

  res = DotProductAttention(
      q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
  return res, state
def forward_with_state(self, inputs, weights, state, rng):
  del weights
  q, k, v = inputs
  if self._mode == 'predict':
    state = _fast_inference_update_state(inputs, state)
    (k, v, mask, _) = state
  else:
    mask_size = q.shape[-2]
    # Not all backends define jnp.tril. However, using np.tril is inefficient
    # in that it creates a large global constant.
    # TODO(kitaev): try to find an alternative that works across all backends.
    if math.backend_name() == 'jax':
      mask = jnp.tril(
          jnp.ones((1, mask_size, mask_size), dtype=np.bool_), k=0)
    else:
      mask = np.tril(
          np.ones((1, mask_size, mask_size), dtype=np.bool_), k=0)

  res = DotProductAttention(
      q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
  return res, state
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  if self._mode in ('train', 'eval'):
    x = inputs
    symbol_size = np.shape(x)[1]
    px = weights[:, :symbol_size, :]
    if self._dropout == 0:
      return (x + px, state)
    else:
      noise_shape = list(px.shape)
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      if math.backend_name() == 'jax':
        keep_prob = jax.lax.tie_in(x, np.full((), keep_prob, dtype=x.dtype))
      keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / keep_prob
      return (x + px * multiplier, state)
  else:
    assert self._mode == 'predict'
    assert self._dropout == 0
    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer needs to store the index of the current
    # position then and increment it on each call -- that's how state is used
    # and updated below.
    return (inputs + np.expand_dims(weights[0, state, :], 1), state + 1)
def _l2_norm(self, flat_list):
  """Returns the aggregate L2 norm of a list of tensors."""
  if math.backend_name() == 'jax':
    norm = np.sqrt(sum(np.vdot(x, x) for x in flat_list))
  else:  # TODO(lukaszkaiser): add vdot to TF-numpy
    norm = np.sqrt(sum(np.sum(x * x) for x in flat_list))
  return norm
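
# --- Illustrative NumPy check (not part of the library) ---
# The vdot-based aggregate norm above equals the L2 norm of all entries
# flattened into a single vector.
import numpy as onp

tensors = [onp.ones((2, 2)), 2.0 * onp.ones(3)]
agg = onp.sqrt(sum(onp.vdot(t, t) for t in tensors))
flat = onp.concatenate([t.ravel() for t in tensors])
assert onp.isclose(agg, onp.linalg.norm(flat))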
def _do_custom_gradients(self, x, weights, state, rng):
  """Calls this layer for a forward pass, but with custom gradients."""
  assert math.backend_name() == 'jax', (
      'Custom gradients are only supported in JAX for now.')

  # See this link for how custom transformations are defined in JAX:
  # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
  @jax.custom_transforms
  def _do_forward(y, weights):
    old_weights, old_state, old_rng = self._weights, self._state, self._rng
    res = self.forward(y, weights)
    s = self._state
    self._weights, self._state, self._rng = old_weights, old_state, old_rng
    return res, s

  # This is the custom gradient (vector-jacobian product in JAX) function.
  # For the exact specification of this custom transformation see this link:
  # https://jax.readthedocs.io/en/latest/jax.html#jax.defvjp_all
  def do_forward_vjp(y, weights):
    """Custom gradient (vjp) function."""
    old_weights, old_state, old_rng = self._weights, self._state, self._rng
    output = self.forward(y, weights)
    new_state = self._state
    self._weights, self._state, self._rng = old_weights, old_state, old_rng
    def vjpfun(grad):
      grad = grad[0]  # Ignore dummy gradient wrt state.
      res = self.backward(y, output, grad, weights, state, new_state, rng)
      return res
    return (output, new_state), vjpfun

  jax.defvjp_all(_do_forward, do_forward_vjp)
  output, state = _do_forward(x, weights)
  state = jax.lax.stop_gradient(state)
  return output, state
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  embs = []
  for ax_emb in weights:
    ax_emb = np.broadcast_to(
        ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],))
    embs.append(ax_emb)
  emb = np.concatenate(embs, -1)

  if self._mode == 'predict':
    assert self._dropout == 0.0
    emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
    return inputs + emb[:, state, :][:, None, :], state + 1
  elif self._dropout == 0:
    return inputs + np.reshape(emb, inputs.shape), state
  else:
    noise_shape = list(emb.shape)
    for dim in self._dropout_broadcast_dims:
      noise_shape[dim] = 1
    keep_prob = 1.0 - self._dropout
    if math.backend_name() == 'jax':
      keep_prob = jax.lax.tie_in(
          inputs, np.full((), keep_prob, dtype=inputs.dtype))
    keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
    multiplier = keep.astype(inputs.dtype) / keep_prob
    return inputs + np.reshape(emb * multiplier, inputs.shape), state
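
# --- Illustrative NumPy sketch (not part of the library) ---
# Axial embeddings as broadcast above, assuming a 2-axis grid: one table per
# axis is broadcast over the full grid, then the tables are concatenated
# along features, so an (H, W) grid stores H*d1 + W*d2 embedding parameters
# instead of a full H*W*(d1+d2) table. Shapes below are made up for the demo.
import numpy as onp

batch, H, W, d1, d2 = 2, 3, 4, 5, 7
ax1 = onp.random.randn(1, H, 1, d1)    # varies along axis 1 only
ax2 = onp.random.randn(1, 1, W, d2)    # varies along axis 2 only
b1 = onp.broadcast_to(ax1, (batch, H, W, d1))
b2 = onp.broadcast_to(ax2, (batch, H, W, d2))
emb_demo = onp.concatenate([b1, b2], -1)   # (batch, H, W, d1 + d2)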
def one_hot(x, n_categories, dtype=np.float32):  # pylint: disable=invalid-name
  """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims)."""
  indices_less_than_n = np.arange(n_categories)
  if math.backend_name() == 'jax':
    # Work around a jax broadcasting issue.
    indices_less_than_n = jax.lax.tie_in(x, indices_less_than_n)
  return np.array(x[..., np.newaxis] == indices_less_than_n, dtype)
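
# --- Illustrative NumPy sketch (not part of the library) ---
# The broadcasting trick above: comparing x[..., None] against arange(n)
# yields a trailing one-hot axis.
import numpy as onp

x_demo = onp.array([0, 2, 1])
print(onp.array(x_demo[..., onp.newaxis] == onp.arange(3), onp.float32))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]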
def f(x):
  if n_devices > 1 and math.backend_name() == 'jax':
    return _multi_device_put(x)
  elif n_devices > 1:
    return jnp.broadcast_to(x, (n_devices,) + x.shape)
  else:
    return x
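
# --- Illustrative NumPy sketch (not part of the library) ---
# The non-JAX replication path above: broadcast_to prepends a device axis so
# every device sees an identical copy of x.
import numpy as onp

n_devices_demo = 4
x_demo = onp.arange(6.0).reshape(2, 3)
replicated = onp.broadcast_to(x_demo, (n_devices_demo,) + x_demo.shape)
assert replicated.shape == (4, 2, 3)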
def forward(self, x):
  """Dropout, with broadcasting to save memory."""
  if self._mode == 'train' and self._rate > 0.0:
    noise_shape = list(x.shape)
    for dim in self._broadcast_dims:
      noise_shape[dim] = 1
    if math.backend_name() == 'jax':
      keep_prob = jax.lax.tie_in(self.rng, 1.0 - self._rate)
    else:
      keep_prob = 1.0 - self._rate
    keep = random.bernoulli(self.rng, keep_prob, tuple(noise_shape))
    if math.backend_name() == 'jax':
      keep_prob = jax.lax.tie_in(keep, keep_prob)
    multiplier = keep.astype(x.dtype) / keep_prob
    return x * multiplier
  else:
    return x
def forward_and_backward(self, inputs, grad, state=base.EMPTY_STATE,
                         new_state=base.EMPTY_STATE, rng=None):
  del new_state
  assert math.backend_name() == 'jax', (
      'JAX backend is required to use forward_and_backward.')

  # Simultaneous forward pass and backprop through the attention mechanism.
  def _do_forward(x):  # pylint: disable=invalid-name
    res, _ = self.forward_with_state(x, state=state, rng=rng)
    return res

  output, vjpfun = jax.vjp(_do_forward, inputs)
  return output, vjpfun(grad)[0]
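
# --- Illustrative JAX sketch (not part of the library) ---
# jax.vjp, as used above, returns the forward output together with a function
# that maps an output cotangent to input gradients. `f_demo` is made up.
import jax
import jax.numpy as jnp_demo

def f_demo(x):
  return jnp_demo.sin(x) * x

x0 = jnp_demo.arange(3.0)
y0, vjp_fn = jax.vjp(f_demo, x0)
(grad_x0,) = vjp_fn(jnp_demo.ones_like(y0))  # gradient of sum(f_demo(x)) at x0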
def _jax_and_tf_configure_for_devices():
  if FLAGS.use_tpu:
    jax.config.update('jax_platform_name', 'tpu')
    jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)
  if FLAGS.enable_eager_execution and math.backend_name() in ('numpy', 'jax'):
    # The numpy backend doesn't benefit from having the input pipeline run on
    # the GPU, and the jax backend has GPU memory contention if TF uses the
    # GPU. Gin must be set up before determining the backend.
    tf.config.experimental.set_visible_devices([], 'GPU')
def main(_):
  logging.set_verbosity(FLAGS.log_level)

  _tf_setup_from_flags()
  _gin_parse_configs()
  _jax_and_tf_configure_for_devices()

  output_dir = _output_dir_or_default()
  if FLAGS.use_tpu and math.backend_name() == 'tf':
    _train_using_tf(output_dir)
  else:
    trainer_lib.train(output_dir=output_dir)

  trainer_lib.log('Finished training.')
def _fast_inference_update_state(inputs, state):
  """Updates state of a causal attention layer for fast inference."""
  assert math.backend_name() == 'jax', (
      'JAX backend is required to use the predict mode.')
  for x in inputs:
    assert x.shape[1] == 1, (
        'In predict mode the input sequence must be of length 1.')
  # Fast inference: run with only 1 query in each step, storing the sequence
  # of keys and values calculated so far in state.
  (_, new_k, new_v) = inputs
  (ks, vs, mask, index) = state
  ks = jax.ops.index_update(ks, jax.ops.index[:, index, :], new_k[:, 0, :])
  vs = jax.ops.index_update(vs, jax.ops.index[:, index, :], new_v[:, 0, :])
  mask = jax.ops.index_update(mask, jax.ops.index[:, :, index], 1)
  return (ks, vs, mask, index + 1)
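
# --- Illustrative NumPy analogue (not part of the library) ---
# What the index_update calls above do: write the current step's key into
# slot `index` of a preallocated cache. NumPy arrays are mutable, so plain
# indexed assignment plays the role of jax.ops.index_update here. Shapes are
# made up for the demo.
import numpy as onp

batch, max_len, d = 2, 8, 4
ks_demo = onp.zeros((batch, max_len, d))
new_k_demo = onp.ones((batch, 1, d))
index_demo = 3
ks_demo[:, index_demo, :] = new_k_demo[:, 0, :]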
def policy(self, trajectory):
  """Chooses an action to play after a trajectory."""
  model = self._policy_collect_model
  model.weights = self._policy_trainer.model_weights
  tr_slice = trajectory[-self._max_slice_length:]
  trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
  # Add batch dimension to trajectory_np and run the model.
  pred = model(trajectory_np.observations[None, ...], n_accelerators=1)
  # Pick element 0 from the batch (the only one), last (current) timestep.
  pred = pred[0, -1, :]
  sample = self._policy_dist.sample(pred)
  result = (sample, pred)
  if math.backend_name() == 'jax':
    result = math.nested_map(lambda x: x.copy(), result)
  return result
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `(queries, keys)`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for dropout applied to attention activations
        (based on query-key pairs) before dotting them with values.
    mode: Either 'train' or 'eval'. Dropout applies only in 'train' mode.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.
  """
  d_feature = queries.shape[-1]
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data
    # dependency on the input. Broadcasted copies of these use a lot of
    # memory, so they should be computed at runtime (rather than being global
    # constants).
    if math.backend_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
  if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  return out
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  embs = []
  for ax_emb in weights:
    ax_emb = np.broadcast_to(
        ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],))
    embs.append(ax_emb)

  if self._mode == 'predict':
    assert self._dropout == 0.0
    emb = np.concatenate(embs, -1)
    emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
    emb = jax.lax.dynamic_slice_in_dim(emb, state, inputs.shape[1], axis=1)
    return inputs + emb, state + inputs.shape[1]
  elif self._dropout == 0:
    # TODO(kitaev): concat-then-reshape (as is the case with dropout enabled)
    # leads to memory blow-up on TPU.
    # emb = np.concatenate(embs, -1)
    # return inputs + np.reshape(emb, inputs.shape), state
    return inputs + np.concatenate(
        [np.reshape(emb, inputs.shape[:-1] + (emb.shape[-1],))
         for emb in embs], -1), state
  else:
    emb = np.concatenate(embs, -1)
    noise_shape = list(emb.shape)
    for dim in self._dropout_broadcast_dims:
      noise_shape[dim] = 1
    keep_prob = 1.0 - self._dropout
    if math.backend_name() == 'jax':
      keep_prob = jax.lax.tie_in(
          inputs, np.full((), keep_prob, dtype=inputs.dtype))
    keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
    multiplier = keep.astype(inputs.dtype) / keep_prob
    return inputs + np.reshape(emb * multiplier, inputs.shape), state
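
# --- Illustrative NumPy check (not part of the library) ---
# The reshape-each-then-concat path above is numerically identical to
# concat-then-reshape; it just avoids materializing the wide intermediate
# array. Shapes below are made up for the demo.
import numpy as onp

b, grid, d1, d2 = 2, (3, 4), 5, 7
e1 = onp.random.randn(b, *grid, d1)
e2 = onp.random.randn(b, *grid, d2)
concat_first = onp.concatenate([e1, e2], -1).reshape(b, -1, d1 + d2)
reshape_first = onp.concatenate(
    [e1.reshape(b, -1, d1), e2.reshape(b, -1, d2)], -1)
assert onp.array_equal(concat_first, reshape_first)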
def _fast_inference_update_state(inputs, state):
  """Updates state of a causal attention layer for fast inference."""
  if math.backend_name() != 'jax':
    raise ValueError(f'JAX backend is required in predict mode, but found '
                     f'backend ({math.backend_name()}).')
  for x in inputs:
    if x.shape[1] != 1:
      raise ValueError(f'In predict mode, input sequence must have length 1, '
                       f'instead has length {x.shape[1]}.')
  # Fast inference: run with only 1 query in each step, storing the sequence
  # of keys and values calculated so far in state.
  (_, new_k, new_v) = inputs
  (ks, vs, mask, seq_indices) = state
  batch_indices = jnp.arange(ks.shape[0])
  ks = jax.ops.index_update(
      ks, jax.ops.index[batch_indices, seq_indices, :], new_k[:, 0, :])
  vs = jax.ops.index_update(
      vs, jax.ops.index[batch_indices, seq_indices, :], new_v[:, 0, :])
  mask = jax.ops.index_update(
      mask, jax.ops.index[batch_indices, :, seq_indices], 1)
  return (ks, vs, mask, seq_indices + 1)
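
# --- Illustrative NumPy analogue (not part of the library) ---
# The batched update above: pairing arange(batch) with per-example sequence
# positions writes each example's new key into its own slot. Values are made
# up for the demo.
import numpy as onp

ks_demo = onp.zeros((2, 8, 4))
new_k_demo = onp.ones((2, 1, 4))
batch_idx = onp.arange(2)
seq_idx = onp.array([3, 5])            # each example is at its own position
ks_demo[batch_idx, seq_idx, :] = new_k_demo[:, 0, :]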
def save_state(self, keep):
  """Save trainer state given a possibly replicated opt_state."""
  opt_state = self._opt_state
  if self.n_devices > 1:
    first_replica = lambda x: x[0]
    opt_state = OptState(*math.nested_map(first_replica, opt_state))
  # This line, while optional, allows JAX to transfer arrays from the device
  # to the host in parallel, which is particularly important for cloud TPU.
  if math.backend_name() == 'jax':
    opt_state = jax.device_get(opt_state)
  step, history, model_state = self._step, self._history, self._model_state
  output_dir = self._output_dir

  pkl_module = utils.get_pickle_module()
  weights_file = os.path.join(output_dir, 'model.pkl')
  with tf.io.gfile.GFile(weights_file, 'wb') as f:
    pkl_module.dump((tuple(opt_state), step, history, model_state), f)
  if keep:
    weights_file = os.path.join(output_dir, 'model_{}.pkl'.format(step))
    with tf.io.gfile.GFile(weights_file, 'wb') as f:
      pkl_module.dump((tuple(opt_state), step, history, model_state), f)
  log('Model saved to %s' % weights_file, stdout=False)
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.CrossEntropyLoss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.Adafactor,
          lr_schedule=lr.MultifactorSchedule,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          save_backward_graph=False,
          has_weights=False,
          nontrainable_param_map=None,
          id_to_mask=None,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          custom_train_fn=None):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
        and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model,
        state, rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
        returns a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
        in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
        steps). If None or 0, eval disabled.
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    save_backward_graph: bool, if True, save backward graph to file too.
    has_weights: bool, whether weights are included in the inputs.
    nontrainable_param_map: dict, mapping from model nontrainable parameter
        names to control names in PolicySchedule.
    id_to_mask: id to mask out (None by default).
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint with the highest value of this
        metric.
    checkpoint_lowest: save the checkpoint with the lowest value of this
        metric.
    custom_train_fn: custom train function to call, entirely bypassing this
        one.

  Returns:
    trax.TrainerState
  """
  if custom_train_fn is not None:
    return custom_train_fn(output_dir, model=model)

  n_devices = num_devices()
  # TODO(lukaszkaiser): remove has_weights and id_to_mask (configure loss).
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule, inputs,
                          output_dir,
                          random_seed=random_seed,
                          n_devices=n_devices,
                          checkpoints_at=checkpoints_at,
                          has_weights=has_weights,
                          nontrainable_param_map=nontrainable_param_map,
                          metrics=metrics,
                          id_to_mask=id_to_mask,
                          checkpoint_lowest=checkpoint_lowest,
                          checkpoint_highest=checkpoint_highest)

  epoch_steps = [steps]  # Train-only (no eval) if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  try:
    for epoch_steps in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(epoch_steps, eval_steps)

      # Update nontrainable parameters with new history.
      trainer.update_nontrainable_params()

      # Bookkeeping we do at the first step.
      if trainer.step == 1:
        # Save computation graph (single-device only for now).
        if save_graphs and math.backend_name() == 'jax':
          trainer.save_computation_graphs(save_backward_graph)

        # Save Gin config.
        trainer.save_gin()

    trainer.log_step('Training done')
  finally:
    trainer.close()
  return trainer.state
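
# --- Illustrative sketch (not part of the library) ---
# The epoch schedule above: with eval_frequency=100, epochs run for
# 1, 99, 100, 100, ... steps, so the first evaluation happens after step 1.
import itertools

eval_frequency_demo = 100
epoch_steps_demo = itertools.chain([1, eval_frequency_demo - 1],
                                   itertools.repeat(eval_frequency_demo))
print(list(itertools.islice(epoch_steps_demo, 5)))  # [1, 99, 100, 100, 100]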
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True, metrics=None,
             checkpoint_highest=None, checkpoint_lowest=None):

  self._is_chief, self._n_devices, rng = (
      self._init_host_and_devices(n_devices, random_seed))
  self._should_save_checkpoints = should_save_checkpoints and self._is_chief
  self._checkpoints_at = checkpoints_at or []
  self._should_write_summaries = should_write_summaries
  if not output_dir:
    self._should_save_checkpoints = False
    self._should_write_summaries = False
  self._checkpoint_highest = checkpoint_highest
  self._checkpoint_lowest = checkpoint_lowest
  if metrics is not None:
    self._metrics_dict = metrics
  else:
    self._metrics_dict = _DEFAULT_METRICS
    self._metrics_dict['loss'] = loss_fn
  # Inputs is either an Inputs instance or a function that returns it.
  self._inputs = inputs
  if callable(inputs):  # If we pass a function, e.g., through gin, call it.
    self._inputs = inputs()

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')
  self._model_with_loss = tl.Serial(model_train, loss_fn)

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, self._n_devices))
  shapes, dtypes = self._inputs.example_shape_dtype
  input_signature = tuple(ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

  def new_opt_state_and_model_state(rng):
    """Returns optimizer and model states suitable for training a model."""
    weights, state = self._model_with_loss.init(input_signature, rng=rng)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)

  if math.backend_name() == 'jax':
    # JIT parameter initialization to avoid memory fragmentation.
    new_opt_state_and_model_state = math.jit(new_opt_state_and_model_state)
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(init_rng))

  # Arrange and initialize metrics layers.
  self._metrics = list(sorted(self._metrics_dict.keys()))
  metrics_layers = [self._metrics_dict[m] for m in self._metrics]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  metrics_in_parallel.rng = init_rng
  example_signature = tuple(
      ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype))
  model_predict_eval.init(example_signature)
  self._input_signature = example_signature
  output_signature = model_predict_eval.output_signature(example_signature)
  m_weights, m_state = metrics_in_parallel.init(output_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(
      model_predict_eval, metrics_in_parallel, self._n_devices)
  self._jit_update_fn = _jit_update_fn(
      model_train, loss_fn, opt, self._n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  # TODO(pkozakowski): "Learning rate schedules" are currently able to
  # control all optimizer parameters and model state, so let's rename them
  # accordingly.
  self._lr_schedule = lr_schedule

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None

  self.reset(output_dir)
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: target, source.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
        source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
        nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  if math.backend_name() == 'jax':
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long
    # sequences.
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
    return [
        tl.Embedding(d_model, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size
  # is None. This isn't supported here because (a) Trax shares weights by
  # sharing layer instances, but we need two separate instances to have
  # mode == 'eval' for the encoder but mode == 'predict' for the decoder; and
  # (b) tl.Cache does not work if its sublayers participate in any weight
  # sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, tl.SelfAttention, dropout,
                   ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                          ff_dropout, mode)
      for _ in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),               # tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: np.squeeze(x, (1, 2)), n_out=1)]),
      #                                   # tok_e mask tok_d .....

      # Encode.
      encoder,                            # vec_e mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),               # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),           # tok_d vec_e mask .....
      out_encoder,                        # vec_d vec_e mask .....
      tl.Dup(),                           # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
      tl.LayerNorm(),                     # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),             # vec_d .....
      tl.Dense(output_vocab_size),        # vec_d .....
      tl.LogSoftmax(),                    # vec_d .....
  )
def ReformerNoEncDecAttention(input_vocab_size,
                              output_vocab_size=None,
                              d_model=512,
                              d_ff=2048,
                              d_attention_key=64,
                              d_attention_value=64,
                              n_encoder_layers=6,
                              n_decoder_layers=6,
                              n_heads=8,
                              dropout=0.1,
                              max_len=2048,
                              encoder_attention_type=tl.SelfAttention,
                              encoder_decoder_attention_type=tl.SelfAttention,
                              axial_pos_shape=(),
                              d_axial_pos_embs=None,
                              ff_activation=tl.Relu,
                              ff_use_sru=0,
                              ff_chunk_size=0,
                              ff_dropout=None,
                              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
        source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as
        SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
        SelfAttention
    axial_pos_shape: tuple of ints: input shape to use for the axial position
        encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
        nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  if math.backend_name() == 'jax':
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    if not axial_pos_shape:
      positional_encoding = tl.PositionalEncoding(
          max_len=max_len, dropout=dropout, mode=mode)
    else:
      assert d_axial_pos_embs is not None
      positional_encoding = tl.AxialPositionalEncoding(
          shape=axial_pos_shape, d_embs=d_axial_pos_embs,
          dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
          dropout=dropout, mode=mode)

    return [
        tl.Embedding(d_model, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size
  # is None.
  # This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache
  # does not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, encoder_attention_type, dropout,
                   ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([               # tok_e mask_e tok_e tok_d tok_d
      in_encoder,                     # vec_e mask_e tok_e tok_d tok_d
      tl.Dup(),                       # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = []

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]
  for layer_idx in range(n_decoder_layers):
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),            # tok_e tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: np.squeeze(x, (1, 2)), n_out=1)]),
      #                                   # tok_e mask_e tok_e tok_d tok_d

      # Encode.
      encoder,                            # vec_e mask_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 0, 1, 2]),            # tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),           # stok_d vec_e mask_e tok_e tok_d
      tl.Branch([], _MaskOfRightShiftedArray()),
      #                             # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                  # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder, given their masks.
      tl.Select([2, 0, 3, 1]),      # svec_d mask_d vec_e mask_e tok_e tok_d
      _ConcatWithPadding(),               # vec_ed tok_e tok_d

      # Run (encoder and) decoder blocks.
      tl.Dup(),                             # vec_ed1 vec_ed2 tok_e tok_d
      tl.ReversibleSerial(decoder_blocks),  # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),    # vec_ed tok_e tok_d
      tl.LayerNorm(),                       # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),              # vec_ed tok_e tok_d tok_d
      _StripFromConcatenateWithPadding(),   # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),          # vec_d tok_d
      tl.LogSoftmax(),                      # vec_d tok_d
  )
def _is_jit_init(value=None):
  if value is None:
    value = math.backend_name() == 'jax'
  return value