Example #1
  def _init_host_and_devices(self, n_devices=None, random_seed=None):
    """Initializes host and device attributes for this trainer.

    Args:
      n_devices: Number of devices this trainer will use. If `None`, get the
          number from the backend.
      random_seed: Random seed as the starting point for all random numbers used
          by the trainer. If `None`, calculate one from system time and host id.

    Returns:
      is_chief: True if this trainer has special chief responsibilities.
      n_devices: The passed in value of n_devices or a computed default.
      random_seed: The passed in value of random_seed or a computed default.
    """
    if backend.get_name() == 'jax':
      host_id = jax.host_id()
      host_count = jax.host_count()
    else:
      host_id = 0
      host_count = 1
    is_chief = (host_id == 0)

    device_count = backend.device_count()
    n_devices = n_devices or device_count
    # TODO(lukaszkaiser): remove this restriction when possible.
    if n_devices != device_count and backend.get_name() == 'jax':
      raise ValueError('JAX cannot work yet with n_devices != all devices: '
                       '%d != %d' % (n_devices, device_count))

    if random_seed is None and host_count > 1:
      random_seed = int(1e6 * (host_id + time.time())) % 2**32
    return is_chief, n_devices, init_random_number_generators(random_seed)
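
A minimal standalone sketch of the per-host seed derivation used in the multi-host branch above (not Trax code); mixing the host id with wall-clock time gives each host a distinct 32-bit seed:

import time

def derive_host_seed(host_id):
  # Mix the host id with wall-clock time, then wrap to a 32-bit value.
  return int(1e6 * (host_id + time.time())) % 2**32

print(derive_host_seed(0), derive_host_seed(1))  # different seed on each host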
Example #2
def main(_):

  logging.set_verbosity(FLAGS.log_level)

  if FLAGS.enable_eager_execution:
    tf.compat.v1.enable_eager_execution()

  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)
    backend.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)

  tf.config.optimizer.set_experimental_options(
      {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host}
  )

  tf.config.optimizer.set_experimental_options(
      {'layout_optimizer': FLAGS.tf_opt_layout}
  )

  set_tf_allow_float64(FLAGS.tf_allow_float64)

  _setup_gin()

  if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
    # Numpy backend doesn't benefit from having the input pipeline run on GPU,
    # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
    # set up first before determining the backend.
    tf.config.experimental.set_visible_devices([], 'GPU')

  # Setup output directory
  output_dir = FLAGS.output_dir or _default_output_dir()
  trainer_lib.log('Using --output_dir %s' % output_dir)
  output_dir = os.path.expanduser(output_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update('jax_platform_name', 'tpu')
    jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

  if FLAGS.use_tpu and backend.get_name() == 'tf':
    worker_cpu = tf_init_tpu()
    with tf.device(worker_cpu):
      if trainer_lib.num_devices() == 1:
        # TF's device priority is GPU > CPU > TPU, so we need to explicitly make
        # the TPU core the default device here.
        with tf.device('/device:TPU:0'):
          trainer_lib.train(output_dir=output_dir)
      else:
        trainer_lib.train(output_dir=output_dir)
  else:
    trainer_lib.train(output_dir=output_dir)

  trainer_lib.log('Finished training.')
Example #3
def main(_):

    logging.set_verbosity(FLAGS.log_level)

    if FLAGS.enable_eager_execution:
        tf.enable_eager_execution()

    if FLAGS.tf_xla:
        tf.config.optimizer.set_jit(True)

    tf.config.optimizer.set_experimental_options(
        {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host})

    tf.config.optimizer.set_experimental_options(
        {'layout_optimizer': FLAGS.tf_opt_layout})

    _setup_gin()

    if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
        # Numpy backend doesn't benefit from having the input pipeline run on GPU,
        # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
        # set up first before determining the backend.
        tf.config.experimental.set_visible_devices([], 'GPU')

    # Setup output directory
    output_dir = FLAGS.output_dir or _default_output_dir()
    trainer_lib.log('Using --output_dir %s' % output_dir)
    output_dir = os.path.expanduser(output_dir)

    # If on TPU, let JAX know.
    if FLAGS.use_tpu:
        jax.config.update('jax_platform_name', 'tpu')

    trainer_lib.train(output_dir=output_dir)
Example #4
    def forward_with_state(self,
                           inputs,
                           weights=base.EMPTY_WEIGHTS,
                           state=base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        embs = []
        for ax_emb in weights:
            ax_emb = np.broadcast_to(ax_emb, (inputs.shape[0], ) +
                                     self._shape + (ax_emb.shape[-1], ))
            embs.append(ax_emb)
        emb = np.concatenate(embs, -1)

        if self._mode == 'predict':
            assert self._dropout == 0.0
            emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
            return inputs + emb[:, state, :][:, None, :], state + 1
        elif self._dropout == 0:
            return inputs + np.reshape(emb, inputs.shape), state
        else:
            noise_shape = list(emb.shape)
            for dim in self._dropout_broadcast_dims:
                noise_shape[dim] = 1
            keep_prob = 1.0 - self._dropout
            if backend.get_name() == 'jax':
                keep_prob = jax.lax.tie_in(
                    inputs, np.full((), keep_prob, dtype=inputs.dtype))
            keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape))
            multiplier = keep.astype(inputs.dtype) / keep_prob

            return inputs + np.reshape(emb * multiplier, inputs.shape), state
Example #5
    def forward_and_backward(self,
                             inputs,
                             ct,
                             state,
                             new_state,
                             rng=None,
                             **kwargs):
        assert backend.get_name() == 'jax', (
            'JAX backend is required to use forward_and_backward.')

        if ct is not None and new_state is not tl.EMPTY_STATE:
            recovered_rng = new_state
            is_same = (rng[0] == recovered_rng[0]) & (rng[1]
                                                      == recovered_rng[1])
            is_same = is_same.astype(np.float32)
            # Divides by zero if rngs are not the same, which results in NaNs.
            inputs = (inputs[0] / is_same, inputs[1] / is_same,
                      inputs[2] / is_same)

        def _do_forward(x):  # pylint: disable=invalid-name
            res, _ = self.forward_with_state(x, state=state, rng=rng, **kwargs)
            return res

        output, vjpfun = jax.vjp(_do_forward, inputs)
        return output, vjpfun(ct)[0]
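
For reference, a minimal standalone sketch of the jax.vjp pattern relied on above: it runs the forward pass once and returns a closure that pulls cotangents back through it (plain JAX, not Trax code):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

x = jnp.array(1.0)
output, vjpfun = jax.vjp(f, x)        # forward pass plus backward-pass closure
(grad_x,) = vjpfun(jnp.array(1.0))    # pull a cotangent of 1.0 back through f
print(output, grad_x)                 # grad_x == 2 * cos(1.0)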
Example #6
 def f(x):
     if n > 1 and backend.get_name() == 'jax':
         return _multi_device_put(x)
     elif n > 1:
         return np.broadcast_to(x, (n, ) + x.shape)
     else:
         return x
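
To illustrate the non-JAX branch above, a minimal NumPy sketch of replicating one array across n device slots by prepending a leading axis (not Trax code):

import numpy as onp

n = 4
x = onp.arange(6.0).reshape(2, 3)
replicated = onp.broadcast_to(x, (n,) + x.shape)  # one copy per device slot
print(replicated.shape)  # (4, 2, 3)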
Example #7
def one_hot(x, size, dtype=np.float32):  # pylint: disable=invalid-name
  """Make a n+1 dim one-hot array from n dim int-categorical array."""
  arange_size = np.arange(size)
  if backend.get_name() == 'jax':
    # Work around a jax broadcasting issue.
    arange_size = jax.lax.tie_in(x, arange_size)
  return np.array(x[..., np.newaxis] == arange_size, dtype)
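
The same broadcasting trick in plain NumPy, as a minimal standalone sketch without the Trax backend or the tie_in workaround:

import numpy as onp

def one_hot_np(x, size, dtype=onp.float32):
  # Compare each integer label against [0, ..., size - 1]; broadcasting yields
  # an (..., size) boolean array, which is then cast to the requested dtype.
  return (x[..., onp.newaxis] == onp.arange(size)).astype(dtype)

print(one_hot_np(onp.array([0, 2, 1]), 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]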
Example #8
 def forward_with_state(self,
                        inputs,
                        weights=base.EMPTY_WEIGHTS,
                        state=base.EMPTY_STATE,
                        rng=None,
                        **kwargs):
     if self._mode in ('train', 'eval'):
         x = inputs
         symbol_size = np.shape(x)[1]
         px = weights[:, :symbol_size, :]
         if self._dropout == 0:
             return (x + px, state)
         else:
             noise_shape = list(px.shape)
             for dim in self._dropout_broadcast_dims:
                 noise_shape[dim] = 1
             keep_prob = 1.0 - self._dropout
             if backend.get_name() == 'jax':
                 keep_prob = jax.lax.tie_in(
                     x, np.full((), keep_prob, dtype=x.dtype))
             keep = backend.random.bernoulli(rng, keep_prob,
                                             tuple(noise_shape))
             multiplier = keep.astype(x.dtype) / keep_prob
             return (x + px * multiplier, state)
     else:
         assert self._mode == 'predict'
         assert self._dropout == 0
         # State in this class is only used for fast inference. In that case,
         # the model is called with consecutive elements position-by-position.
         # This positional encoding layer needs to store the index of the current
         # position then and increment it on each call -- that's how state is used
         # and updated below.
         return (inputs + np.expand_dims(weights[:, state, :], 1),
                 state + 1)
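
A minimal NumPy sketch of the broadcast-dropout idea in the training branch above, with the keep mask sampled on a reduced noise shape so one decision is shared along the broadcast dimensions (not Trax code; the shapes here are hypothetical):

import numpy as onp

def broadcast_dropout_np(x, rate, noise_shape):
    # Sample keep/drop on `noise_shape` (dims set to 1 are shared via
    # broadcasting), then rescale survivors by 1/keep_prob so the expected
    # value of the output matches the input.
    keep_prob = 1.0 - rate
    keep = onp.random.random(noise_shape) < keep_prob
    return x * keep.astype(x.dtype) / keep_prob

x = onp.ones((2, 5, 8))
y = broadcast_dropout_np(x, rate=0.1, noise_shape=(1, 5, 8))  # mask shared over batch
print(y.shape)  # (2, 5, 8)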
Example #9
    def forward_with_state(self,
                           inputs,
                           weights=base.EMPTY_WEIGHTS,
                           state=base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        del weights
        q, k, v = inputs
        if self._mode in ('train', 'eval'):
            mask_size = q.shape[-2]
            # Not all backends define np.tril. However, using onp.tril is inefficient
            # in that it creates a large global constant. TODO(kitaev): try to find an
            # alternative that works across all backends.
            if backend.get_name() == 'jax':
                mask = np.tril(np.ones((1, mask_size, mask_size),
                                       dtype=onp.bool_),
                               k=0)
            else:
                mask = onp.tril(onp.ones((1, mask_size, mask_size),
                                         dtype=onp.bool_),
                                k=0)
        else:
            assert self._mode == 'predict'
            state = _fast_inference_update_state(inputs, state)
            (k, v, mask, _) = state

        res = DotProductAttention(q,
                                  k,
                                  v,
                                  mask,
                                  dropout=self._dropout,
                                  mode=self._mode,
                                  rng=rng)
        return res, state
Example #10
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if backend.get_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
    if dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
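
A NumPy-only sketch of the same scaled dot-product attention with a causal mask, omitting dropout and the tie_in workaround (not Trax code; shapes are hypothetical):

import numpy as onp

def dot_product_attention_np(query, key, value, mask=None):
    depth = query.shape[-1]
    dots = onp.matmul(query, onp.swapaxes(key, -1, -2)) / onp.sqrt(depth)
    if mask is not None:
        dots = onp.where(mask, dots, onp.full_like(dots, -1e9))
    # Numerically stable softmax over the last axis.
    dots = onp.exp(dots - dots.max(axis=-1, keepdims=True))
    dots = dots / dots.sum(axis=-1, keepdims=True)
    return onp.matmul(dots, value)

q = k = v = onp.random.rand(1, 5, 8)                     # (batch, length, depth)
causal_mask = onp.tril(onp.ones((1, 5, 5), dtype=bool))
print(dot_product_attention_np(q, k, v, causal_mask).shape)  # (1, 5, 8)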
Example #11
    def _do_custom_gradients(self, x, weights, state, **kwargs):
        """Calls this layer for a forward pass, but with custom gradients."""
        assert backend.get_name() == 'jax', (
            'Custom gradients are only supported in JAX for now.')

        # TODO(wangpeng): JAX doesn't support custom grads for functions with
        #   auxiliary output yet (https://github.com/google/jax/issues/844). Will
        #   remove the constraints on state below when this feature is added to
        #   JAX.

        assert not jax.tree_util.tree_leaves(state), (
            'Custom gradients require trivial start state. Got %s' %
            str(state))

        def check_end_state(output_state):
            output, state = output_state
            assert not jax.tree_util.tree_leaves(state), (
                'Custom gradients require trivial end state. Got %s' %
                str(state))
            return output

        # See this link for how custom transformations are defined in JAX:
        # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
        # Note that we capture the kwargs and don't calculate gradients wrt. them.
        @jax.custom_transforms
        def _do_forward(y, weights):
            return check_end_state(
                self.forward_with_state(y,
                                        weights=weights,
                                        state=state,
                                        **kwargs))

        # This is the custom gradient (vector-jacobian product in JAX) function.
        # For the exact specification of this custom transformation see this link:
        # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
        def do_forward_vjp(y, weights):
            """Custom gradient (vjp) function."""
            stash = None
            if Layer._STASH_IN is None:
                Layer._STASH_IN = stash = {}
            output = check_end_state(
                self.forward_with_state(y,
                                        weights=weights,
                                        state=state,
                                        **kwargs))
            if stash is not None:
                Layer._STASH_IN = None

            def vjpfun(grad):
                assert Layer._STASH_OUT is None
                Layer._STASH_OUT = stash
                res = self.backward(y, output, grad, weights, state, **kwargs)
                Layer._STASH_OUT = None
                return res

            return output, vjpfun

        jax.defvjp_all(_do_forward, do_forward_vjp)
        return _do_forward(x, weights), state
Example #12
 def _maybe_replicate(self, x):
   if self._n_devices > 1:
     if backend.get_name() == 'jax':
       return multi_device_put(x)
     else:
       return np.broadcast_to(x, (self._n_devices,) + x.shape)
   else:
     return x
Example #13
def _jax_and_tf_configure_for_devices():
  if FLAGS.use_tpu:
    jax.config.update('jax_platform_name', 'tpu')
    jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)
  if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
    # Numpy backend doesn't benefit from having the input pipeline run on GPU,
    # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
    # set up first before determining the backend.
    tf.config.experimental.set_visible_devices([], 'GPU')
Example #14
    def forward_and_backward(self, inputs, ct, state, new_state, **kwargs):
        assert backend.get_name() == 'jax', (
            'JAX backend is required to use forward_and_backward.')

        # Simultaneous forward pass and backprop through the attention mechanism.
        def _do_forward(x):  # pylint: disable=invalid-name
            res, _ = self.forward_with_state(x, state=state, **kwargs)
            return res

        output, vjpfun = jax.vjp(_do_forward, inputs)
        return output, vjpfun(ct)[0]
Example #15
def _save_replicated(opt_state, step, history, model_state, n_devices,
                     output_dir, keep):
  """Saves trainer state but given a possibly replicated opt_state."""
  if n_devices > 1:
    first_replica = lambda x: x[0]
    opt_state = OptState(*layers.nested_map(first_replica, opt_state))
  # This line, while optional, allows JAX to transfer arrays from the device to
  # the host in parallel, which is particularly important for cloud TPU.
  if backend.get_name() == 'jax':
    opt_state = jax.device_get(opt_state)
  save_trainer_state(
      TrainerState(opt_state=opt_state, step=step, history=history,
                   model_state=model_state), output_dir, keep=keep)
Example #16
 def __init__(self, loop_stride, dropout, mode, share_qk=False, hard_k=0):
     assert backend.get_name() == 'jax', (
         'JAX backend is required to use MemoryEfficientCausalAttention.')
     super(MemoryEfficientCausalAttention, self).__init__()
     self._loop_stride = loop_stride
     if dropout >= 1.0:
         raise ValueError('Dropout rates must be lower than 1.')
     if mode == 'train':
         self.dropout = dropout
     else:
         self.dropout = None
     self._share_qk = share_qk
     self._hard_k = hard_k
Example #17
def main(_):
  logging.set_verbosity(FLAGS.log_level)

  _tf_setup_from_flags()
  _gin_parse_configs()
  _jax_and_tf_configure_for_devices()

  output_dir = _output_dir_or_default()
  if FLAGS.use_tpu and backend.get_name() == 'tf':
    _train_using_tf(output_dir)
  else:
    trainer_lib.train(output_dir=output_dir)

  trainer_lib.log('Finished training.')
Example #18
def _fast_inference_update_state(inputs, state):
    """Updates state of a causal attention layer for fast inference."""
    assert backend.get_name() == 'jax', (
        'JAX backend is required to use the predict mode.')
    for x in inputs:
        assert x.shape[1] == 1, (
            'In predict mode the input sequence must be of length 1.')
    # Fast inference: run with only 1 query in each step, storing the sequence
    # of keys and values calculated so far in state.
    (_, new_k, new_v) = inputs
    (ks, vs, mask, index) = state
    ks = jax.ops.index_update(ks, jax.ops.index[:, index, :], new_k[:, 0, :])
    vs = jax.ops.index_update(vs, jax.ops.index[:, index, :], new_v[:, 0, :])
    mask = jax.ops.index_update(mask, jax.ops.index[:, :, index], 1)
    return (ks, vs, mask, index + 1)
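
A NumPy analogue of the functional index updates above, as a minimal sketch (hypothetical shapes; in JAX the updates return new arrays rather than mutating in place):

import numpy as onp

batch, max_len, d_kv = 2, 8, 4
ks = onp.zeros((batch, max_len, d_kv))
vs = onp.zeros((batch, max_len, d_kv))
mask = onp.zeros((batch, 1, max_len))
index = 3                                  # number of positions decoded so far
new_k = onp.random.rand(batch, 1, d_kv)    # key for the single new token
new_v = onp.random.rand(batch, 1, d_kv)

# Store the new key/value at `index` and open the attention mask there.
ks[:, index, :] = new_k[:, 0, :]
vs[:, index, :] = new_v[:, 0, :]
mask[:, :, index] = 1
index += 1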
Example #19
  def save_state(self, keep):
    """Save trainer state given a possibly replicated opt_state."""
    opt_state = self._opt_state
    if self.n_devices > 1:
      first_replica = lambda x: x[0]
      opt_state = OptState(*backend.nested_map(first_replica, opt_state))
    # This line, while optional, allows JAX to transfer arrays from the device
    # to the host in parallel, which is particularly important for cloud TPU.
    if backend.get_name() == 'jax':
      opt_state = jax.device_get(opt_state)
    step, history, model_state = self._step, self._history, self._model_state
    output_dir = self._output_dir

    pkl_module = utils.get_pickle_module()
    weights_file = os.path.join(output_dir, 'model.pkl')
    with tf.io.gfile.GFile(weights_file, 'wb') as f:
      pkl_module.dump((tuple(opt_state), step, history, model_state), f)
    if keep:
      weights_file = os.path.join(output_dir, 'model_{}.pkl'.format(step))
      with tf.io.gfile.GFile(weights_file, 'wb') as f:
        pkl_module.dump((tuple(opt_state), step, history, model_state), f)
    log('Model saved to %s' % weights_file, stdout=False)
Example #20
    def _do_custom_gradients(self, x, weights, state, **kwargs):
        """Calls this layer for a forward pass, but with custom gradients."""
        assert backend.get_name() == 'jax', (
            'Custom gradients are only supported in JAX for now.')

        # See this link for how custom transformations are defined in JAX:
        # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
        # Note that we capture the kwargs and don't calculate gradients wrt. them.
        @jax.custom_transforms
        def _do_forward(y, weights):
            res = self.forward_with_state(y,
                                          weights=weights,
                                          state=state,
                                          **kwargs)
            return res

        # This is the custom gradient (vector-jacobian product in JAX) function.
        # For the exact specification of this custom transformation see this link:
        # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
        def do_forward_vjp(y, weights):
            """Custom gradient (vjp) function."""
            output, new_state = self.forward_with_state(y,
                                                        weights=weights,
                                                        state=state,
                                                        **kwargs)

            def vjpfun(grad):
                grad = grad[0]  # Ignore dummy gradient wrt state.
                res = self.backward(y, output, grad, weights, state, new_state,
                                    **kwargs)
                return res

            return (output, state), vjpfun

        jax.defvjp_all(_do_forward, do_forward_vjp)
        output, state = _do_forward(x, weights)
        state = jax.lax.stop_gradient(state)
        return output, state
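
The jax.custom_transforms / jax.defvjp_all API used above comes from older JAX releases; newer JAX typically expresses the same pattern with jax.custom_vjp. A minimal standalone sketch of that newer API (not Trax code):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    return jnp.sin(x), x           # output plus residuals for the backward pass

def f_bwd(residual, ct):
    x = residual
    return (ct * jnp.cos(x),)      # cotangent with respect to x

f.defvjp(f_fwd, f_bwd)
print(jax.grad(f)(0.0))            # 1.0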
Example #21
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 save_steps=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 has_weights=False,
                 nontrainable_param_map=None,
                 mask_id=None,
                 metrics=None):
        if backend.get_name() == 'jax':
            self._host_id = jax.host_id()
            self._host_count = jax.host_count()
        else:
            self._host_id = 0
            self._host_count = 1
        self._is_chief = (self._host_id == 0)

        if save_steps is None:
            save_steps = []
        self._save_steps = save_steps
        self._should_save_checkpoints = should_save_checkpoints
        self._should_write_summaries = should_write_summaries
        self._has_weights = has_weights
        self._mask_id = mask_id
        self._metrics_dict = _METRICS if metrics is None else metrics
        loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
        device_count = backend.device_count()
        n_devices = n_devices or device_count
        # TODO(lukaszkaiser): remove this restriction when possible.
        if n_devices != device_count and backend.get_name() == 'jax':
            raise ValueError(
                'JAX cannot work yet with n_devices != all devices: '
                '%d != %d' % (n_devices, device_count))
        self._n_devices = n_devices

        # Simple differential seeding of RNG across hosts by host_id and time.
        if random_seed is None and self._host_count > 1:
            _, random_seed = divmod(
                int(time.time() * 1e6) + int(self._host_id * 1e6), 2**32)
        rng = get_random_number_generator_and_set_seed(random_seed)
        inputs = inputs(n_devices)
        self._inputs = inputs

        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, n_devices))
        first_shape = inputs.input_shape[0]
        # If the inputs are a tuple/list, add [None] (batch) to each element.
        if isinstance(first_shape, (list, tuple)):
            model_input_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.input_shape)
            model_target_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.target_shape)
        else:  # Otherwise just add [None] to the input shape.
            model_input_shape = tuple([None] + list(inputs.input_shape))
            model_target_shape = tuple([None] + list(inputs.target_shape))
        # Change all None to 1 in input and target shape.
        model_input_shape = backend.nested_map(lambda x: x or 1,
                                               model_input_shape)
        model_target_shape = backend.nested_map(lambda x: x or 1,
                                                model_target_shape)

        def new_opt_state_and_model_state(input_shape, input_dtype,
                                          target_shape, target_dtype, rng):
            """Returns optimizer and model states suitable for training a model."""
            # Combine inputs and targets on the stack.
            if not isinstance(input_dtype, (list, tuple)):
                input_dtype = [input_dtype]
                input_shape = [input_shape]
            if not isinstance(target_dtype, (list, tuple)):
                target_dtype = [target_dtype]
                target_shape = [target_shape]
            dtypes = list(input_dtype) + list(target_dtype)
            shapes = list(input_shape) + list(target_shape)
            if self._has_weights:
                shapes += list(target_shape)
                dtypes += [np.float32 for _ in target_dtype]
            input_signature = tuple(
                ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
            # We need to create a new model instance and not reuse `model_train` here,
            # because `m.initialize` puts cached parameter values in `m` and hence the
            # next call of `m.initialize` will give wrong results.
            m = tl.Serial(model(mode='train'), loss_fn)
            m._set_rng_recursive(rng)  # pylint: disable=protected-access
            weights, state = m.init(input_signature)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if _is_jit_init():
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = backend.jit(
                new_opt_state_and_model_state, static_argnums=(0, 1, 2, 3))
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
                model_input_shape, self._inputs.input_dtype,
                model_target_shape, self._inputs.target_dtype, init_rng))

        # Arrange and initialize metrics layers.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [
            self._metrics_dict[m](has_weights=self._has_weights,
                                  mask_id=self._mask_id) for m in self._metrics
        ]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        # TODO(lukaszkaiser): clean this up once layer API stabilizes.
        # For now, we need to initialize metric layers somehow, so here we go.
        # We assume that they do not have any parameters, so this is a dummy.
        dummy_shapes = ((1, 2), (1, ),
                        (1, )) if self._has_weights else ((1, 2), (1, ))
        dummy_signature = tuple(ShapeDtype(s) for s in dummy_shapes)
        metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
        m_weights, m_state = metrics_in_parallel.init(dummy_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        if nontrainable_param_map is None:
            nontrainable_param_map = {}
        self._nontrainable_param_map = nontrainable_param_map

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None

        if output_dir is not None:
            self.reset(output_dir)
Example #22
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.CrossEntropyLossScalar,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.Adafactor,
          lr_schedule=lr.MultifactorSchedule,
          trainer_class=Trainer,
          train_steps=1000,
          save_steps=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          save_backward_graph=False,
          has_weights=False,
          nontrainable_param_map=None,
          mask_id=None,
          metrics=None):
    """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state,
      rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    train_steps: int, total number of training steps.
    save_steps: list of integers. Keep a model file at each of the supplied save
      steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    save_backward_graph: bool, if True, save backward graph to file too.
    has_weights: bool, whether weights are included in the inputs.
    nontrainable_param_map: dict, mapping from model nontrainable parameter
      names to control names in PolicySchedule.
    mask_id: id to mask out (None by default).
    metrics: optionally override the default metrics dictionary.

  Returns:
    trax.TrainerState
  """
    n_devices = num_devices()
    # TODO(lukaszkaiser): remove has_weights and mask_id later (configure loss).
    trainer = trainer_class(model,
                            loss_fn,
                            optimizer,
                            lr_schedule,
                            inputs,
                            output_dir,
                            random_seed=random_seed,
                            n_devices=n_devices,
                            save_steps=save_steps,
                            has_weights=has_weights,
                            nontrainable_param_map=nontrainable_param_map,
                            metrics=metrics,
                            mask_id=mask_id)

    epoch_steps = [train_steps]  # Only training if eval_frequency is 0 or None
    if eval_frequency and eval_steps > 0:
        epoch_steps = itertools.chain(
            [
                1,  # first epoch only 1 step
                eval_frequency - 1
            ],
            itertools.repeat(eval_frequency))
    trainer.log_step('Starting training using %d devices' % trainer.n_devices)
    trainer.print_n_weights()

    for epoch_steps in epochs(train_steps, trainer.step, epoch_steps):
        trainer.train_epoch(epoch_steps, eval_steps)

        # Update nontrainable parameters with new history
        trainer.update_nontrainable_params()

        # Bookkeeping we do at the first step
        if trainer.step == 1:
            # Save computation graph (single-device only for now)
            if (save_graphs and backend.get_name() == 'jax'):
                trainer.save_computation_graphs(save_backward_graph)

            # Save Gin config
            trainer.save_gin()

    trainer.log_step('Training done')
    return trainer.state
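
A small standalone illustration of the epoch-step schedule built with itertools above (hypothetical eval_frequency): the first epoch runs a single step, the second fills out the remaining eval_frequency - 1 steps, and every later epoch runs eval_frequency steps.

import itertools

eval_frequency = 100
epoch_steps = itertools.chain([1, eval_frequency - 1],
                              itertools.repeat(eval_frequency))
print(list(itertools.islice(epoch_steps, 4)))  # [1, 99, 100, 100]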