def _init_host_and_devices(self, n_devices=None, random_seed=None):
  """Initializes host and device attributes for this trainer.

  Args:
    n_devices: Number of devices this trainer will use. If `None`, get the
        number from the backend.
    random_seed: Random seed as the starting point for all random numbers used
        by the trainer. If `None`, calculate one from system time and host id.

  Returns:
    is_chief: True if this trainer has special chief responsibilities.
    n_devices: The passed in value of n_devices or a computed default.
    random_seed: The passed in value of random_seed or a computed default.
  """
  if backend.get_name() == 'jax':
    host_id = jax.host_id()
    host_count = jax.host_count()
  else:
    host_id = 0
    host_count = 1
  is_chief = (host_id == 0)

  device_count = backend.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count and backend.get_name() == 'jax':
    raise ValueError('JAX cannot work yet with n_devices != all devices: '
                     '%d != %d' % (n_devices, device_count))

  if random_seed is None and host_count > 1:
    random_seed = int(1e6 * (host_id + time.time())) % 2**32
  return is_chief, n_devices, init_random_number_generators(random_seed)
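# Illustrative sketch (standard library only, hypothetical host ids) of the
# per-host seeding above: when `random_seed` is None and several hosts are
# training together, each host derives a distinct 32-bit seed from its
# host_id and the current wall-clock time.
import time

now = time.time()
per_host_seeds = [int(1e6 * (host_id + now)) % 2**32 for host_id in range(4)]
# Host ids shift the value by 1e6 before the modulo, so hosts sharing the
# same wall-clock time still end up with different seeds.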
def main(_):
  logging.set_verbosity(FLAGS.log_level)

  if FLAGS.enable_eager_execution:
    tf.compat.v1.enable_eager_execution()

  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)
    backend.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)

  tf.config.optimizer.set_experimental_options(
      {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host}
  )
  tf.config.optimizer.set_experimental_options(
      {'layout_optimizer': FLAGS.tf_opt_layout}
  )

  set_tf_allow_float64(FLAGS.tf_allow_float64)

  _setup_gin()

  if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
    # Numpy backend doesn't benefit from having the input pipeline run on GPU,
    # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
    # set up first before determining the backend.
    tf.config.experimental.set_visible_devices([], 'GPU')

  # Setup output directory
  output_dir = FLAGS.output_dir or _default_output_dir()
  trainer_lib.log('Using --output_dir %s' % output_dir)
  output_dir = os.path.expanduser(output_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update('jax_platform_name', 'tpu')
    jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

  if FLAGS.use_tpu and backend.get_name() == 'tf':
    worker_cpu = tf_init_tpu()
    with tf.device(worker_cpu):
      if trainer_lib.num_devices() == 1:
        # TF's device priority is GPU > CPU > TPU, so we need to explicitly
        # make the TPU core the default device here.
        with tf.device('/device:TPU:0'):
          trainer_lib.train(output_dir=output_dir)
      else:
        trainer_lib.train(output_dir=output_dir)
  else:
    trainer_lib.train(output_dir=output_dir)

  trainer_lib.log('Finished training.')
def main(_):
  logging.set_verbosity(FLAGS.log_level)

  if FLAGS.enable_eager_execution:
    tf.enable_eager_execution()

  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)

  tf.config.optimizer.set_experimental_options(
      {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host})
  tf.config.optimizer.set_experimental_options(
      {'layout_optimizer': FLAGS.tf_opt_layout})

  _setup_gin()

  if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
    # Numpy backend doesn't benefit from having the input pipeline run on GPU,
    # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
    # set up first before determining the backend.
    tf.config.experimental.set_visible_devices([], 'GPU')

  # Setup output directory
  output_dir = FLAGS.output_dir or _default_output_dir()
  trainer_lib.log('Using --output_dir %s' % output_dir)
  output_dir = os.path.expanduser(output_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update('jax_platform_name', 'tpu')

  trainer_lib.train(output_dir=output_dir)
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  embs = []
  for ax_emb in weights:
    ax_emb = np.broadcast_to(
        ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],))
    embs.append(ax_emb)
  emb = np.concatenate(embs, -1)

  if self._mode == 'predict':
    assert self._dropout == 0.0
    emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
    return inputs + emb[:, state, :][:, None, :], state + 1
  elif self._dropout == 0:
    return inputs + np.reshape(emb, inputs.shape), state
  else:
    noise_shape = list(emb.shape)
    for dim in self._dropout_broadcast_dims:
      noise_shape[dim] = 1
    keep_prob = 1.0 - self._dropout
    if backend.get_name() == 'jax':
      keep_prob = jax.lax.tie_in(
          inputs, np.full((), keep_prob, dtype=inputs.dtype))
    keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape))
    multiplier = keep.astype(inputs.dtype) / keep_prob
    return inputs + np.reshape(emb * multiplier, inputs.shape), state
def forward_and_backward(self, inputs, ct, state, new_state, rng=None,
                         **kwargs):
  assert backend.get_name() == 'jax', (
      'JAX backend is required to use forward_and_backward.')

  if ct is not None and new_state is not tl.EMPTY_STATE:
    recovered_rng = new_state
    is_same = (rng[0] == recovered_rng[0]) & (rng[1] == recovered_rng[1])
    is_same = is_same.astype(np.float32)
    # Divides by zero if rngs are not the same, which results in NaNs.
    inputs = (inputs[0] / is_same, inputs[1] / is_same, inputs[2] / is_same)

  def _do_forward(x):  # pylint: disable=invalid-name
    res, _ = self.forward_with_state(x, state=state, rng=rng, **kwargs)
    return res

  output, vjpfun = jax.vjp(_do_forward, inputs)
  return output, vjpfun(ct)[0]
def f(x):
  if n > 1 and backend.get_name() == 'jax':
    return _multi_device_put(x)
  elif n > 1:
    return np.broadcast_to(x, (n,) + x.shape)
  else:
    return x
def one_hot(x, size, dtype=np.float32):  # pylint: disable=invalid-name
  """Make an (n+1)-dim one-hot array from an n-dim int-categorical array."""
  arange_size = np.arange(size)
  if backend.get_name() == 'jax':
    # Work around a jax broadcasting issue.
    arange_size = jax.lax.tie_in(x, arange_size)
  return np.array(x[..., np.newaxis] == arange_size, dtype)
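# Illustrative sketch of the broadcast-compare trick above, using plain NumPy
# (no Trax backend) on a hypothetical (2, 2) batch of class ids.
import numpy as onp

x = onp.array([[0, 2], [1, 1]])        # int-categorical, shape (2, 2)
arange_size = onp.arange(3)            # class ids 0..2
one_hot_x = onp.array(x[..., onp.newaxis] == arange_size, onp.float32)
# one_hot_x.shape == (2, 2, 3); e.g. one_hot_x[0, 1] == [0., 0., 1.]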
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  if self._mode in ('train', 'eval'):
    x = inputs
    symbol_size = np.shape(x)[1]
    px = weights[:, :symbol_size, :]
    if self._dropout == 0:
      return (x + px, state)
    else:
      noise_shape = list(px.shape)
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      if backend.get_name() == 'jax':
        keep_prob = jax.lax.tie_in(
            x, np.full((), keep_prob, dtype=x.dtype))
      keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / keep_prob
      return (x + px * multiplier, state)
  else:
    assert self._mode == 'predict'
    assert self._dropout == 0
    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer needs to store the index of the current
    # position then and increment it on each call -- that's how state is used
    # and updated below.
    return (inputs + np.expand_dims(weights[:, state, :], 1), state + 1)
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  del weights
  q, k, v = inputs
  if self._mode in ('train', 'eval'):
    mask_size = q.shape[-2]
    # Not all backends define np.tril. However, using onp.tril is inefficient
    # in that it creates a large global constant. TODO(kitaev): try to find an
    # alternative that works across all backends.
    if backend.get_name() == 'jax':
      mask = np.tril(
          np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
    else:
      mask = onp.tril(
          onp.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
  else:
    assert self._mode == 'predict'
    state = _fast_inference_update_state(inputs, state)
    (k, v, mask, _) = state

  res = DotProductAttention(
      q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
  return res, state
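# Illustrative sketch of the causal mask built above, with plain NumPy and a
# hypothetical mask_size of 4: row i allows attention to positions <= i.
import numpy as onp

mask_size = 4
mask = onp.tril(onp.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
# mask[0, 1] == [True, True, False, False]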
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data
    # dependency on the input. Broadcasted copies of these use a lot of
    # memory, so they should be computed at runtime (rather than being
    # global constants).
    if backend.get_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
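# Illustrative sketch (plain NumPy/SciPy, hypothetical shapes) of the masked
# softmax used above: masked positions are pushed to -1e9 before the
# logsumexp-based softmax, so they receive (numerically) zero attention weight.
import numpy as onp
from scipy.special import logsumexp

dots = onp.array([[0.5, 1.0, 2.0]])
mask = onp.array([[True, True, False]])          # last position masked out
dots = onp.where(mask, dots, onp.full_like(dots, -1e9))
weights = onp.exp(dots - logsumexp(dots, axis=-1, keepdims=True))
# weights sums to 1 over the unmasked positions; weights[0, 2] is ~0.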
def _do_custom_gradients(self, x, weights, state, **kwargs):
  """Calls this layer for a forward pass, but with custom gradients."""
  assert backend.get_name() == 'jax', (
      'Custom gradients are only supported in JAX for now.')

  # TODO(wangpeng): JAX doesn't support custom grads for functions with
  #   auxiliary output yet (https://github.com/google/jax/issues/844). Will
  #   remove the constraints on state below when this feature is added to JAX.
  assert not jax.tree_util.tree_leaves(state), (
      'Custom gradients require trivial start state. Got %s' % str(state))

  def check_end_state(output_state):
    output, state = output_state
    assert not jax.tree_util.tree_leaves(state), (
        'Custom gradients require trivial end state. Got %s' % str(state))
    return output

  # See this link for how custom transformations are defined in JAX:
  # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
  # Note that we capture the kwargs and don't calculate gradients wrt. them.
  @jax.custom_transforms
  def _do_forward(y, weights):
    return check_end_state(self.forward_with_state(
        y, weights=weights, state=state, **kwargs))

  # This is the custom gradient (vector-jacobian product in JAX) function.
  # For the exact specification of this custom transformation see this link:
  # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
  def do_forward_vjp(y, weights):
    """Custom gradient (vjp) function."""
    stash = None
    if Layer._STASH_IN is None:
      Layer._STASH_IN = stash = {}
    output = check_end_state(self.forward_with_state(
        y, weights=weights, state=state, **kwargs))
    if stash is not None:
      Layer._STASH_IN = None
    def vjpfun(grad):
      assert Layer._STASH_OUT is None
      Layer._STASH_OUT = stash
      res = self.backward(y, output, grad, weights, state, **kwargs)
      Layer._STASH_OUT = None
      return res
    return output, vjpfun

  jax.defvjp_all(_do_forward, do_forward_vjp)
  return _do_forward(x, weights), state
def _maybe_replicate(self, x):
  if self._n_devices > 1:
    if backend.get_name() == 'jax':
      return multi_device_put(x)
    else:
      return np.broadcast_to(x, (self._n_devices,) + x.shape)
  else:
    return x
def _jax_and_tf_configure_for_devices():
  if FLAGS.use_tpu:
    jax.config.update('jax_platform_name', 'tpu')
    jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)
  if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
    # Numpy backend doesn't benefit from having the input pipeline run on GPU,
    # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
    # set up first before determining the backend.
    tf.config.experimental.set_visible_devices([], 'GPU')
def forward_and_backward(self, inputs, ct, state, new_state, **kwargs):
  assert backend.get_name() == 'jax', (
      'JAX backend is required to use forward_and_backward.')
  # Simultaneous forward pass and backprop through the attention mechanism.
  def _do_forward(x):  # pylint: disable=invalid-name
    res, _ = self.forward_with_state(x, state=state, **kwargs)
    return res
  output, vjpfun = jax.vjp(_do_forward, inputs)
  return output, vjpfun(ct)[0]
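# Minimal sketch of the jax.vjp pattern used above, on a hypothetical
# elementwise function: jax.vjp returns the forward output together with a
# function that maps an output cotangent to gradients wrt the inputs.
import jax
import jax.numpy as jnp

x = jnp.arange(3.0)
output, vjpfun = jax.vjp(jnp.sin, x)
(grad_wrt_x,) = vjpfun(jnp.ones_like(output))   # equals jnp.cos(x)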
def _save_replicated(opt_state, step, history, model_state, n_devices,
                     output_dir, keep):
  """Saves trainer state, given a possibly replicated opt_state."""
  if n_devices > 1:
    first_replica = lambda x: x[0]
    opt_state = OptState(*layers.nested_map(first_replica, opt_state))
  # This line, while optional, allows JAX to transfer arrays from the device
  # to the host in parallel, which is particularly important for cloud TPU.
  if backend.get_name() == 'jax':
    opt_state = jax.device_get(opt_state)
  save_trainer_state(
      TrainerState(opt_state=opt_state, step=step, history=history,
                   model_state=model_state), output_dir, keep=keep)
def __init__(self, loop_stride, dropout, mode, share_qk=False, hard_k=0):
  assert backend.get_name() == 'jax', (
      'JAX backend is required to use MemoryEfficientCausalAttention.')
  super(MemoryEfficientCausalAttention, self).__init__()
  self._loop_stride = loop_stride
  if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if mode == 'train':
    self.dropout = dropout
  else:
    self.dropout = None
  self._share_qk = share_qk
  self._hard_k = hard_k
def main(_):
  logging.set_verbosity(FLAGS.log_level)

  _tf_setup_from_flags()
  _gin_parse_configs()
  _jax_and_tf_configure_for_devices()

  output_dir = _output_dir_or_default()
  if FLAGS.use_tpu and backend.get_name() == 'tf':
    _train_using_tf(output_dir)
  else:
    trainer_lib.train(output_dir=output_dir)

  trainer_lib.log('Finished training.')
def _fast_inference_update_state(inputs, state):
  """Updates state of a causal attention layer for fast inference."""
  assert backend.get_name() == 'jax', (
      'JAX backend is required to use the predict mode.')
  for x in inputs:
    assert x.shape[1] == 1, (
        'In predict mode the input sequence must be of length 1.')
  # Fast inference: run with only 1 query in each step, storing the sequence
  # of keys and values calculated so far in state.
  (_, new_k, new_v) = inputs
  (ks, vs, mask, index) = state
  ks = jax.ops.index_update(ks, jax.ops.index[:, index, :], new_k[:, 0, :])
  vs = jax.ops.index_update(vs, jax.ops.index[:, index, :], new_v[:, 0, :])
  mask = jax.ops.index_update(mask, jax.ops.index[:, :, index], 1)
  return (ks, vs, mask, index + 1)
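# Illustrative sketch (plain NumPy, hypothetical sizes) of the fast-inference
# state update above: each step writes one new key/value into a pre-allocated
# buffer at `index` and extends the causal mask by one column. The functional
# jax.ops.index_update is stood in for by ordinary NumPy assignment here.
import numpy as onp

batch, max_len, d = 1, 8, 4
ks = onp.zeros((batch, max_len, d))
vs = onp.zeros((batch, max_len, d))
mask = onp.zeros((batch, 1, max_len))
index = 0

new_k = onp.ones((batch, 1, d))   # key for the current single-token query
new_v = onp.ones((batch, 1, d))
ks[:, index, :] = new_k[:, 0, :]
vs[:, index, :] = new_v[:, 0, :]
mask[:, :, index] = 1
index += 1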
def save_state(self, keep):
  """Save trainer state given a possibly replicated opt_state."""
  opt_state = self._opt_state
  if self.n_devices > 1:
    first_replica = lambda x: x[0]
    opt_state = OptState(*backend.nested_map(first_replica, opt_state))
  # This line, while optional, allows JAX to transfer arrays from the device
  # to the host in parallel, which is particularly important for cloud TPU.
  if backend.get_name() == 'jax':
    opt_state = jax.device_get(opt_state)

  step, history, model_state = self._step, self._history, self._model_state
  output_dir = self._output_dir

  pkl_module = utils.get_pickle_module()
  weights_file = os.path.join(output_dir, 'model.pkl')
  with tf.io.gfile.GFile(weights_file, 'wb') as f:
    pkl_module.dump((tuple(opt_state), step, history, model_state), f)
  if keep:
    weights_file = os.path.join(output_dir, 'model_{}.pkl'.format(step))
    with tf.io.gfile.GFile(weights_file, 'wb') as f:
      pkl_module.dump((tuple(opt_state), step, history, model_state), f)
  log('Model saved to %s' % weights_file, stdout=False)
def _do_custom_gradients(self, x, weights, state, **kwargs):
  """Calls this layer for a forward pass, but with custom gradients."""
  assert backend.get_name() == 'jax', (
      'Custom gradients are only supported in JAX for now.')

  # See this link for how custom transformations are defined in JAX:
  # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
  # Note that we capture the kwargs and don't calculate gradients wrt. them.
  @jax.custom_transforms
  def _do_forward(y, weights):
    res = self.forward_with_state(y, weights=weights, state=state, **kwargs)
    return res

  # This is the custom gradient (vector-jacobian product in JAX) function.
  # For the exact specification of this custom transformation see this link:
  # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
  def do_forward_vjp(y, weights):
    """Custom gradient (vjp) function."""
    output, new_state = self.forward_with_state(
        y, weights=weights, state=state, **kwargs)
    def vjpfun(grad):
      grad = grad[0]  # Ignore dummy gradient wrt state.
      res = self.backward(y, output, grad, weights, state, new_state,
                          **kwargs)
      return res
    return (output, state), vjpfun

  jax.defvjp_all(_do_forward, do_forward_vjp)
  output, state = _do_forward(x, weights)
  state = jax.lax.stop_gradient(state)
  return output, state
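# Minimal sketch of the custom-gradient pattern above, written against the
# newer jax.custom_vjp API (jax.custom_transforms / defjvp_all / defvjp_all
# are older JAX interfaces). The toy function and its vjp here are
# hypothetical, not the layer's actual backward rule.
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return jnp.sin(x)

def f_fwd(x):
  return jnp.sin(x), x             # forward output plus residuals

def f_bwd(residuals, grad):
  x = residuals
  return (grad * jnp.cos(x),)      # cotangent wrt the single input

f.defvjp(f_fwd, f_bwd)
print(jax.grad(f)(0.0))            # 1.0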
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             save_steps=None, should_save_checkpoints=True,
             should_write_summaries=True, has_weights=False,
             nontrainable_param_map=None, mask_id=None, metrics=None):
  if backend.get_name() == 'jax':
    self._host_id = jax.host_id()
    self._host_count = jax.host_count()
  else:
    self._host_id = 0
    self._host_count = 1
  self._is_chief = (self._host_id == 0)

  if save_steps is None:
    save_steps = []
  self._save_steps = save_steps
  self._should_save_checkpoints = should_save_checkpoints
  self._should_write_summaries = should_write_summaries
  self._has_weights = has_weights
  self._mask_id = mask_id
  self._metrics_dict = _METRICS if metrics is None else metrics
  loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
  device_count = backend.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count and backend.get_name() == 'jax':
    raise ValueError('JAX cannot work yet with n_devices != all devices: '
                     '%d != %d' % (n_devices, device_count))
  self._n_devices = n_devices

  # Simple differential seeding of RNG across hosts by host_id and time.
  if random_seed is None and self._host_count > 1:
    _, random_seed = divmod(
        int(time.time() * 1e6) + int(self._host_id * 1e6), 2**32)
  rng = get_random_number_generator_and_set_seed(random_seed)
  inputs = inputs(n_devices)
  self._inputs = inputs

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, n_devices))
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
    model_target_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.target_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
    model_target_shape = tuple([None] + list(inputs.target_shape))
  # Change all None to 1 in input and target shape.
  model_input_shape = backend.nested_map(lambda x: x or 1, model_input_shape)
  model_target_shape = backend.nested_map(lambda x: x or 1,
                                          model_target_shape)

  def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
                                    target_dtype, rng):
    """Returns optimizer and model states suitable for training a model."""
    # Combine inputs and targets on the stack.
    if not isinstance(input_dtype, (list, tuple)):
      input_dtype = [input_dtype]
      input_shape = [input_shape]
    if not isinstance(target_dtype, (list, tuple)):
      target_dtype = [target_dtype]
      target_shape = [target_shape]
    dtypes = list(input_dtype) + list(target_dtype)
    shapes = list(input_shape) + list(target_shape)
    if self._has_weights:
      shapes += list(target_shape)
      dtypes += [np.float32 for _ in target_dtype]
    input_signature = tuple(ShapeDtype(s, d)
                            for (s, d) in zip(shapes, dtypes))
    # We need to create a new model instance and not reuse `model_train` here,
    # because `m.initialize` puts cached parameter values in `m` and hence the
    # next call of `m.initialize` will give wrong results.
    m = tl.Serial(model(mode='train'), loss_fn)
    m._set_rng_recursive(rng)  # pylint: disable=protected-access
    weights, state = m.init(input_signature)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)

  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation.
    new_opt_state_and_model_state = backend.jit(
        new_opt_state_and_model_state, static_argnums=(0, 1, 2, 3))
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
          model_input_shape, self._inputs.input_dtype,
          model_target_shape, self._inputs.target_dtype, init_rng))

  # Arrange and initialize metrics layers.
  self._metrics = list(sorted(self._metrics_dict.keys()))
  metrics_layers = [self._metrics_dict[m](has_weights=self._has_weights,
                                          mask_id=self._mask_id)
                    for m in self._metrics]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  # TODO(lukaszkaiser): clean this up once layer API stabilizes.
  # For now, we need to initialize metric layers somehow, so here we go.
  # We assume that they do not have any parameters, so this is a dummy.
  dummy_shapes = ((1, 2), (1,), (1,)) if self._has_weights else ((1, 2), (1,))
  dummy_signature = tuple(ShapeDtype(s) for s in dummy_shapes)
  metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
  m_weights, m_state = metrics_in_parallel.init(dummy_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(
      model_predict_eval, metrics_in_parallel, n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  # TODO(pkozakowski): "Learning rate schedules" are currently able to
  # control all optimizer parameters and model state, so let's rename them
  # accordingly.
  self._lr_schedule = lr_schedule

  if nontrainable_param_map is None:
    nontrainable_param_map = {}
  self._nontrainable_param_map = nontrainable_param_map

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None

  if output_dir is not None:
    self.reset(output_dir)
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.CrossEntropyLossScalar,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.Adafactor,
          lr_schedule=lr.MultifactorSchedule,
          trainer_class=Trainer,
          train_steps=1000,
          save_steps=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          save_backward_graph=False,
          has_weights=False,
          nontrainable_param_map=None,
          mask_id=None,
          metrics=None):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state,
      rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    train_steps: int, total number of training steps.
    save_steps: list of integers. Keep a model file at each of the supplied
      save steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    save_backward_graph: bool, if True, save backward graph to file too.
    has_weights: bool, whether weights are included in the inputs.
    nontrainable_param_map: dict, mapping from model nontrainable parameter
      names to control names in PolicySchedule.
    mask_id: id to mask out (None by default).
    metrics: optionally override the default metrics dictionary.

  Returns:
    trax.TrainerState
  """
  n_devices = num_devices()
  # TODO(lukaszkaiser): remove has_weights and mask_id later (configure loss).
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule, inputs,
                          output_dir,
                          random_seed=random_seed, n_devices=n_devices,
                          save_steps=save_steps, has_weights=has_weights,
                          nontrainable_param_map=nontrainable_param_map,
                          metrics=metrics, mask_id=mask_id)

  epoch_steps = [train_steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  for epoch_steps in epochs(train_steps, trainer.step, epoch_steps):
    trainer.train_epoch(epoch_steps, eval_steps)

    # Update nontrainable parameters with new history.
    trainer.update_nontrainable_params()

    # Bookkeeping we do at the first step.
    if trainer.step == 1:
      # Save computation graph (single-device only for now).
      if (save_graphs and backend.get_name() == 'jax'):
        trainer.save_computation_graphs(save_backward_graph)

      # Save Gin config.
      trainer.save_gin()

  trainer.log_step('Training done')
  return trainer.state