def train_for_one_epoch(
    dataset_source: dataset_source_lib.DatasetSource,
    optimizer: flax.optim.Optimizer,
    state: flax.nn.Collection,
    prng_key: jnp.ndarray,
    pmapped_train_step: _TrainStep,
    pmapped_update_ema: Optional[_EMAUpdateStep],
    moving_averages: Optional[efficientnet_optim.ExponentialMovingAverage],
    summary_writer: tensorboard.SummaryWriter
) -> Tuple[flax.optim.Optimizer, flax.nn.Collection,
           Optional[efficientnet_optim.ExponentialMovingAverage]]:
  """Runs one epoch of training and logs the averaged metrics.

  Args:
    dataset_source: Container for the training dataset.
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm
      moving averages).
    prng_key: PRNG key for per-step stochasticity (e.g. sampling an eventual
      dropout mask). Not used for shuffling the dataset.
    pmapped_train_step: A pmapped version of the `train_step` function (see
      its documentation for more details).
    pmapped_update_ema: Function updating the parameter moving average; None
      when EMA is not used.
    moving_averages: Parameter moving averages, if used.
    summary_writer: Tensorboard SummaryWriter used to log the epoch metrics.

  Returns:
    The updated optimizer (with the associated updated model), state and
    parameter moving averages.
  """
  epoch_start = time.time()
  steps_done = 0
  per_step_metrics = []
  for batch in dataset_source.get_train(use_augmentations=True):
    # Fold the global step into the key so each step gets fresh, reproducible
    # randomness.
    step_key = jax.random.fold_in(prng_key, optimizer.state.step[0])
    # Convert the TF batch to numpy arrays and split it across local devices.
    device_batch = shard_batch(tensorflow_to_numpy(batch))
    sharded_keys = common_utils.shard_prng_key(step_key)
    optimizer, state, metrics, lr = pmapped_train_step(
        optimizer, state, device_batch, sharded_keys)
    steps_done += 1
    if moving_averages is not None:
      moving_averages = pmapped_update_ema(optimizer, state, moving_averages)
    per_step_metrics.append(metrics)

  # Average the per-step metrics over the whole epoch before logging.
  stacked_metrics = common_utils.get_metrics(per_step_metrics)
  train_summary = jax.tree_map(lambda x: x.mean(), stacked_metrics)
  train_summary['learning_rate'] = lr[0]
  current_step = int(optimizer.state.step[0])
  info = 'Whole training step done in {} ({} steps)'.format(
      time.time() - epoch_start, steps_done)
  logging.info(info)
  for metric_name, metric_value in train_summary.items():
    summary_writer.scalar(metric_name, metric_value, current_step)
  summary_writer.flush()
  return optimizer, state, moving_averages
def local_train_loop(key,
                     init_params,
                     loss_fn,
                     summarize_fn=default_summarize,
                     lr=1e-4,
                     num_steps=int(1e5),
                     summarize_every=100,
                     checkpoint_every=5000,
                     clobber_checkpoint=False,
                     logdir="/tmp/lda_inference"):
  """Single-device training loop with periodic checkpoints and summaries."""
  optimizer = optim.Adam().create(init_params)
  # Resume from an existing checkpoint in logdir unless asked to clobber it.
  optimizer = util.maybe_load_checkpoint(
      logdir, optimizer, clobber_checkpoint=clobber_checkpoint)
  lr_fn = util.create_learning_rate_scheduler(base_learning_rate=lr)

  @jit
  def train_step(opt, step_key):
    # Loss value and gradient w.r.t. the parameters in a single pass.
    loss_val, loss_grad = jax.value_and_grad(loss_fn, argnums=0)(
        opt.target, step_key)
    updated = opt.apply_gradient(
        loss_grad, learning_rate=lr_fn(opt.state.step))
    return loss_val, updated

  writer = SummaryWriter(logdir)

  def save_checkpoint():
    # keep=3 bounds the number of checkpoints retained on disk.
    checkpoints.save_checkpoint(logdir, optimizer, optimizer.state.step,
                                keep=3)
    print("Checkpoint saved for step %d" % optimizer.state.step)

  tic = timeit.default_timer()
  first_step = optimizer.state.step
  for step in range(optimizer.state.step, num_steps):
    # Don't immediately re-save the checkpoint we just resumed from.
    if step % checkpoint_every == 0 and step != first_step:
      save_checkpoint()
    key, subkey = jax.random.split(key)
    try:
      loss_val, new_optimizer = train_step(optimizer, subkey)
    except FloatingPointError as e:
      # Dump the pre-step state and the offending key before exiting so the
      # failure can be reproduced and debugged.
      print("Exception on step %d" % step)
      print(e)
      traceback.print_exc()
      save_checkpoint()
      print("key ", subkey)
      sys.stdout.flush()
      sys.exit(1)
    optimizer = new_optimizer
    if step % summarize_every == 0:
      key, subkey = jax.random.split(key)
      print("Step %d loss: %0.4f" % (step, loss_val))
      writer.scalar("loss", loss_val, step=step)
      summarize_fn(writer, step, optimizer.target, subkey)
      toc = timeit.default_timer()
      # The very first interval covers one step, not summarize_every steps.
      steps_per_sec = (1. if step == 0 else summarize_every) / (toc - tic)
      print("Steps/sec: %0.2f" % steps_per_sec)
      writer.scalar("steps_per_sec", steps_per_sec, step=step)
      tic = toc
      writer.flush()
      sys.stdout.flush()
def parallel_train_loop(key,
                        init_params,
                        loss_fn,
                        summarize_fn=default_summarize,
                        lr=1e-4,
                        num_steps=int(1e5),
                        summarize_every=100,
                        checkpoint_every=5000,
                        clobber_checkpoint=False,
                        logdir="/tmp/lda_inference"):
  """Data-parallel (pmap) training loop with checkpoints and summaries."""
  loss_fn = jax.jit(loss_fn)
  local_optimizer = optim.Adam().create(init_params)
  # Resume from an existing checkpoint in logdir unless asked to clobber it.
  local_optimizer = util.maybe_load_checkpoint(
      logdir, local_optimizer, clobber_checkpoint=clobber_checkpoint)
  first_step = local_optimizer.state.step
  # One optimizer replica per local device.
  repl_optimizer = jax_utils.replicate(local_optimizer)
  lr_fn = util.create_learning_rate_scheduler(base_learning_rate=lr)

  @functools.partial(jax.pmap, axis_name="batch")
  def train_step(opt, step_key):
    step_key, next_key = jax.random.split(step_key)
    grad = jax.grad(loss_fn, argnums=0)(opt.target, step_key)
    # Average gradients across devices before applying them.
    grad = jax.lax.pmean(grad, "batch")
    updated = opt.apply_gradient(grad, learning_rate=lr_fn(opt.state.step))
    return updated, next_key

  writer = SummaryWriter(logdir)
  # Give each local device its own independent PRNG stream.
  repl_key = jax.pmap(jax.random.PRNGKey)(
      jnp.arange(jax.local_device_count()))
  tic = timeit.default_timer()
  for step in range(first_step, num_steps):
    # Don't immediately re-save the checkpoint we just resumed from.
    if step % checkpoint_every == 0 and step != first_step:
      optimizer = jax_utils.unreplicate(repl_optimizer)
      checkpoints.save_checkpoint(logdir, optimizer, optimizer.state.step,
                                  keep=3)
      print("Checkpoint saved for step %d" % optimizer.state.step)
    repl_optimizer, repl_key = train_step(repl_optimizer, repl_key)
    if step % summarize_every == 0:
      key, subkey = jax.random.split(jax_utils.unreplicate(repl_key))
      optimizer = jax_utils.unreplicate(repl_optimizer)
      # Recompute the loss on a single device for logging.
      loss_val = loss_fn(optimizer.target, key)
      print("Step %d loss: %0.4f" % (step, loss_val))
      writer.scalar("loss", loss_val, step=step)
      summarize_fn(writer, step, optimizer.target, subkey)
      toc = timeit.default_timer()
      # The very first interval covers one step, not summarize_every steps.
      steps_per_sec = (1. if step == 0 else summarize_every) / (toc - tic)
      print("Steps/sec: %0.2f" % steps_per_sec)
      writer.scalar("steps_per_sec", steps_per_sec, step=step)
      tic = toc
      writer.flush()
      sys.stdout.flush()
def test_summarywriter_flush_after_close(self):
  """Flushing a closed SummaryWriter should raise AttributeError."""
  workdir = tempfile.mkdtemp()
  writer = SummaryWriter(log_dir=workdir)
  writer.close()
  # After close() the underlying event writer is gone; flush must fail.
  with self.assertRaises(AttributeError):
    writer.flush()