def train_for_one_epoch(
    dataset_source: dataset_source_lib.DatasetSource,
    optimizer: flax.optim.Optimizer,
    state: flax.nn.Collection,
    prng_key: jnp.ndarray,
    pmapped_train_step: _TrainStep,
    pmapped_update_ema: Optional[_EMAUpdateStep],
    moving_averages: Optional[efficientnet_optim.ExponentialMovingAverage],
    summary_writer: tensorboard.SummaryWriter
) -> Tuple[flax.optim.Optimizer, flax.nn.Collection,
           Optional[efficientnet_optim.ExponentialMovingAverage]]:
  """Trains the model for one epoch.

  Args:
    dataset_source: Container for the training dataset.
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm
      moving averages).
    prng_key: A PRNG key to use for stochasticity (e.g. for sampling an
      eventual dropout mask). Is not used for shuffling the dataset.
    pmapped_train_step: A pmapped version of the `train_step` function (see its
      documentation for more details).
    pmapped_update_ema: Function to update the parameter moving average. Can be
      None if we don't use EMA.
    moving_averages: Parameters moving average if used.
    summary_writer: A Tensorboard SummaryWriter to use to log metrics.

  Returns:
    The updated optimizer (with the associated updated model), state and
      parameter moving averages.
  """
  start_time = time.time()
  cnt = 0
  train_metrics = []
  for batch in dataset_source.get_train(use_augmentations=True):
    # Generate a PRNG key that will be rolled into the batch.
    step_key = jax.random.fold_in(prng_key, optimizer.state.step[0])
    # Load and shard the TF batch.
    batch = tensorflow_to_numpy(batch)
    batch = shard_batch(batch)
    # Shard the step PRNG key.
    sharded_keys = common_utils.shard_prng_key(step_key)
    optimizer, state, metrics, lr = pmapped_train_step(
        optimizer, state, batch, sharded_keys)
    cnt += 1
    if moving_averages is not None:
      moving_averages = pmapped_update_ema(optimizer, state, moving_averages)
    train_metrics.append(metrics)
  train_metrics = common_utils.get_metrics(train_metrics)
  # Get training epoch summary for logging.
  train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
  train_summary['learning_rate'] = lr[0]
  current_step = int(optimizer.state.step[0])
  info = 'Whole training epoch done in {:.2f}s ({} steps)'.format(
      time.time() - start_time, cnt)
  logging.info(info)
  for metric_name, metric_value in train_summary.items():
    summary_writer.scalar(metric_name, metric_value, current_step)
  summary_writer.flush()
  return optimizer, state, moving_averages
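# A minimal sketch of how train_for_one_epoch might be driven from an outer
# loop. It assumes `dataset_source`, `optimizer`, `state`, the pmapped
# functions, `moving_averages`, and `summary_writer` were built by the
# surrounding training script; `num_epochs` and the per-epoch key fold are
# illustrative assumptions, not part of the original code.
prng_key = jax.random.PRNGKey(0)
for epoch in range(num_epochs):
  # Fold in the epoch index so stochastic ops (e.g. dropout) differ per epoch.
  epoch_key = jax.random.fold_in(prng_key, epoch)
  optimizer, state, moving_averages = train_for_one_epoch(
      dataset_source, optimizer, state, epoch_key, pmapped_train_step,
      pmapped_update_ema, moving_averages, summary_writer)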
def process_iterator(tag: str,
                     item_ids: Sequence[str],
                     iterator,
                     rng: types.PRNGKey,
                     state: model_utils.TrainState,
                     step: int,
                     render_fn: Any,
                     summary_writer: tensorboard.SummaryWriter,
                     save_dir: Optional[gpath.GPath],
                     datasource: datasets.DataSource):
  """Process a dataset iterator and compute metrics."""
  save_dir = save_dir / f'{step:08d}' / tag if save_dir else None
  meters = collections.defaultdict(utils.ValueMeter)
  for i, (item_id, batch) in enumerate(zip(item_ids, iterator)):
    logging.info('[%s:%d/%d] Processing %s ', tag, i + 1, len(item_ids),
                 item_id)
    if tag == 'test':
      test_rng = random.PRNGKey(step)
      shape = batch['origins'][..., :1].shape
      metadata = {}
      if datasource.use_appearance_id:
        appearance_id = random.choice(
            test_rng, jnp.asarray(datasource.appearance_ids))
        logging.info('\tUsing appearance_id = %d', appearance_id)
        metadata['appearance'] = jnp.full(
            shape, fill_value=appearance_id, dtype=jnp.uint32)
      if datasource.use_warp_id:
        warp_id = random.choice(test_rng, jnp.asarray(datasource.warp_ids))
        logging.info('\tUsing warp_id = %d', warp_id)
        metadata['warp'] = jnp.full(shape, fill_value=warp_id,
                                    dtype=jnp.uint32)
      if datasource.use_camera_id:
        camera_id = random.choice(test_rng,
                                  jnp.asarray(datasource.camera_ids))
        logging.info('\tUsing camera_id = %d', camera_id)
        metadata['camera'] = jnp.full(shape, fill_value=camera_id,
                                      dtype=jnp.uint32)
      if datasource.use_time:
        timestamp = random.uniform(test_rng, minval=0.0, maxval=1.0)
        # %f rather than %d: `timestamp` is a float in [0, 1).
        logging.info('\tUsing time = %f', timestamp)
        # float32 rather than uint32: a uint32 cast would truncate every
        # timestamp to 0.
        metadata['time'] = jnp.full(
            shape, fill_value=timestamp, dtype=jnp.float32)
      batch['metadata'] = metadata
    stats = process_batch(batch=batch,
                          rng=rng,
                          state=state,
                          tag=tag,
                          item_id=item_id,
                          step=step,
                          render_fn=render_fn,
                          summary_writer=summary_writer,
                          save_dir=save_dir,
                          datasource=datasource)
    if jax.process_index() == 0:
      for k, v in stats.items():
        meters[k].update(v)

  if jax.process_index() == 0:
    for meter_name, meter in meters.items():
      summary_writer.scalar(tag=f'metrics-eval/{meter_name}/{tag}',
                            value=meter.reduce('mean'),
                            step=step)
def test_summarywriter_scalar(self):
  log_dir = tempfile.mkdtemp()
  summary_writer = SummaryWriter(log_dir=log_dir)
  # Write the scalar, then check that the event exists and carries the
  # expected data.
  float_value = 99.1232
  summary_writer.scalar(tag='scalar_test', value=float_value, step=1)
  summary_value = self.parse_and_return_summary_value(path=log_dir)
  self.assertEqual(summary_value.tag, 'scalar_test')
  self.assertTrue(onp.allclose(
      tensor_util.make_ndarray(summary_value.tensor).item(),
      float_value))
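# The test above relies on a parse_and_return_summary_value helper that is
# not shown. A minimal sketch of one possible implementation, assuming
# TensorFlow is available and a single event file was written under `path`:
import glob
import os

import tensorflow as tf

def parse_and_return_summary_value(path):
  """Returns the first tagged summary value found in the event file."""
  event_file = glob.glob(os.path.join(path, 'events.out.tfevents.*'))[0]
  for event in tf.compat.v1.train.summary_iterator(event_file):
    for value in event.summary.value:
      if value.tag:
        return value
  raise ValueError(f'No summary value found in {path}.')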
def local_train_loop(key,
                     init_params,
                     loss_fn,
                     summarize_fn=default_summarize,
                     lr=1e-4,
                     num_steps=int(1e5),
                     summarize_every=100,
                     checkpoint_every=5000,
                     clobber_checkpoint=False,
                     logdir="/tmp/lda_inference"):
  optimizer_def = optim.Adam()
  optimizer = optimizer_def.create(init_params)
  optimizer = util.maybe_load_checkpoint(
      logdir, optimizer, clobber_checkpoint=clobber_checkpoint)
  lr_fn = util.create_learning_rate_scheduler(base_learning_rate=lr)

  def train_step(optimizer, key):
    loss_val, loss_grad = jax.value_and_grad(
        loss_fn, argnums=0)(optimizer.target, key)
    new_optimizer = optimizer.apply_gradient(
        loss_grad, learning_rate=lr_fn(optimizer.state.step))
    return loss_val, new_optimizer

  train_step = jit(train_step)

  sw = SummaryWriter(logdir)
  start = timeit.default_timer()
  first_step = optimizer.state.step
  for t in range(optimizer.state.step, num_steps):
    if t % checkpoint_every == 0 and t != first_step:
      checkpoints.save_checkpoint(logdir, optimizer, optimizer.state.step,
                                  keep=3)
      print("Checkpoint saved for step %d" % optimizer.state.step)
    key, subkey = jax.random.split(key)
    try:
      loss_val, new_optimizer = train_step(optimizer, subkey)
    except FloatingPointError as e:
      print("Exception on step %d" % t)
      print(e)
      traceback.print_exc()
      checkpoints.save_checkpoint(logdir, optimizer, optimizer.state.step,
                                  keep=3)
      print("Checkpoint saved for step %d" % optimizer.state.step)
      print("key ", subkey)
      sys.stdout.flush()
      sys.exit(1)
    optimizer = new_optimizer
    if t % summarize_every == 0:
      key, subkey = jax.random.split(key)
      print("Step %d loss: %0.4f" % (t, loss_val))
      sw.scalar("loss", loss_val, step=t)
      summarize_fn(sw, t, optimizer.target, subkey)
      end = timeit.default_timer()
      if t == 0:
        steps_per_sec = 1. / (end - start)
      else:
        steps_per_sec = summarize_every / (end - start)
      print("Steps/sec: %0.2f" % steps_per_sec)
      sw.scalar("steps_per_sec", steps_per_sec, step=t)
      start = end
      sw.flush()
      sys.stdout.flush()
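# The `except FloatingPointError` branch above only fires when JAX is told to
# raise on NaNs/Infs; by default NaNs propagate silently. A minimal sketch of
# the standard flag that enables this behavior, set once at program start:
import jax

jax.config.update("jax_debug_nans", True)
# With the flag on, an op that produces a NaN is re-run outside of jit and
# raises FloatingPointError, which the loop above catches to checkpoint
# before exiting.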
def parallel_train_loop(key,
                        init_params,
                        loss_fn,
                        summarize_fn=default_summarize,
                        lr=1e-4,
                        num_steps=int(1e5),
                        summarize_every=100,
                        checkpoint_every=5000,
                        clobber_checkpoint=False,
                        logdir="/tmp/lda_inference"):
  loss_fn = jax.jit(loss_fn)
  optimizer_def = optim.Adam()
  local_optimizer = optimizer_def.create(init_params)
  local_optimizer = util.maybe_load_checkpoint(
      logdir, local_optimizer, clobber_checkpoint=clobber_checkpoint)
  first_step = local_optimizer.state.step
  repl_optimizer = jax_utils.replicate(local_optimizer)
  lr_fn = util.create_learning_rate_scheduler(base_learning_rate=lr)

  @functools.partial(jax.pmap, axis_name="batch")
  def train_step(optimizer, key):
    key, subkey = jax.random.split(key)
    loss_grad = jax.grad(loss_fn, argnums=0)(optimizer.target, key)
    loss_grad = jax.lax.pmean(loss_grad, "batch")
    new_optimizer = optimizer.apply_gradient(
        loss_grad, learning_rate=lr_fn(optimizer.state.step))
    return new_optimizer, subkey

  sw = SummaryWriter(logdir)
  repl_key = jax.pmap(jax.random.PRNGKey)(
      jnp.arange(jax.local_device_count()))
  start = timeit.default_timer()
  for t in range(first_step, num_steps):
    if t % checkpoint_every == 0 and t != first_step:
      optimizer = jax_utils.unreplicate(repl_optimizer)
      checkpoints.save_checkpoint(logdir, optimizer, optimizer.state.step,
                                  keep=3)
      print("Checkpoint saved for step %d" % optimizer.state.step)
    repl_optimizer, repl_key = train_step(repl_optimizer, repl_key)
    if t % summarize_every == 0:
      key, subkey = jax.random.split(jax_utils.unreplicate(repl_key))
      optimizer = jax_utils.unreplicate(repl_optimizer)
      loss_val = loss_fn(optimizer.target, key)
      print("Step %d loss: %0.4f" % (t, loss_val))
      sw.scalar("loss", loss_val, step=t)
      summarize_fn(sw, t, optimizer.target, subkey)
      end = timeit.default_timer()
      if t == 0:
        steps_per_sec = 1. / (end - start)
      else:
        steps_per_sec = summarize_every / (end - start)
      print("Steps/sec: %0.2f" % steps_per_sec)
      sw.scalar("steps_per_sec", steps_per_sec, step=t)
      start = end
      sw.flush()
      sys.stdout.flush()
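# The parallel loop above relies on flax.jax_utils.replicate/unreplicate to
# move between one copy of a pytree and one copy per device. A minimal sketch
# of that round trip (the parameter pytree here is illustrative):
import jax
import jax.numpy as jnp
from flax import jax_utils

params = {"w": jnp.zeros((3,))}
repl = jax_utils.replicate(params)  # stacks a leading device axis
assert repl["w"].shape == (jax.local_device_count(), 3)
restored = jax_utils.unreplicate(repl)  # takes the first replica
assert restored["w"].shape == (3,)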
for i, batch_idx in enumerate(
    tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
  samples = [
      tokenized_datasets["validation"][int(idx)] for idx in batch_idx
  ]
  model_inputs = data_collator(samples, pad_to_multiple_of=16)

  # Model forward
  model_inputs = common_utils.shard(model_inputs.data)
  metrics = p_eval_step(optimizer.target, model_inputs)
  eval_metrics.append(metrics)

eval_metrics_np = get_metrics(eval_metrics)
eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np)
eval_normalizer = eval_metrics_np.pop("normalizer")
eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)

# Update progress bar
epochs.desc = (
    f"Epoch... ({epoch + 1}/{nb_epochs} | Loss: {eval_summary['loss']}, "
    f"Acc: {eval_summary['accuracy']})"
)

if wandb_args.wandb_user_name is not None:
  wandb.log({"Eval loss": np.array(eval_summary["loss"]).mean()})

# Save metrics
if has_tensorboard and jax.host_id() == 0:
  for name, value in eval_summary.items():
    summary_writer.scalar(name, value, epoch)
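# common_utils.shard, used above, reshapes each array's leading batch axis
# into (local_device_count, per_device_batch) so p_eval_step receives one
# slice per device. A minimal sketch with an illustrative batch whose size is
# assumed divisible by the local device count:
import jax
import jax.numpy as jnp
from flax.training import common_utils

batch = {"input_ids": jnp.zeros((32, 128), dtype=jnp.int32)}
sharded = common_utils.shard(batch)
n_dev = jax.local_device_count()
assert sharded["input_ids"].shape == (n_dev, 32 // n_dev, 128)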
def process_iterator(
    tag: str,
    item_ids: Sequence[str],
    iterator,
    rng: types.PRNGKey,
    state: model_utils.TrainState,
    step: int,
    render_fn: Any,
    summary_writer: tensorboard.SummaryWriter,
    save_dir: Optional[gpath.GPath],
    datasource: datasets.DataSource,
):
  """Process a dataset iterator and compute metrics."""
  save_dir = save_dir / f"{step:08d}" / tag if save_dir else None
  meters = collections.defaultdict(utils.ValueMeter)
  for i, (item_id, batch) in enumerate(zip(item_ids, iterator)):
    logging.info("[%s:%d/%d] Processing %s ", tag, i + 1, len(item_ids),
                 item_id)
    if tag == "test":
      test_rng = random.PRNGKey(step)
      shape = batch["origins"][..., :1].shape
      metadata = {}
      if datasource.use_appearance_id:
        appearance_id = random.choice(
            test_rng, jnp.asarray(datasource.appearance_ids))
        logging.info("\tUsing appearance_id = %d", appearance_id)
        metadata["appearance"] = jnp.full(
            shape, fill_value=appearance_id, dtype=jnp.uint32)
      if datasource.use_warp_id:
        warp_id = random.choice(test_rng, jnp.asarray(datasource.warp_ids))
        logging.info("\tUsing warp_id = %d", warp_id)
        metadata["warp"] = jnp.full(shape, fill_value=warp_id,
                                    dtype=jnp.uint32)
      if datasource.use_camera_id:
        camera_id = random.choice(test_rng,
                                  jnp.asarray(datasource.camera_ids))
        logging.info("\tUsing camera_id = %d", camera_id)
        metadata["camera"] = jnp.full(shape, fill_value=camera_id,
                                      dtype=jnp.uint32)
      batch["metadata"] = metadata
    stats = process_batch(
        batch=batch,
        rng=rng,
        state=state,
        tag=tag,
        item_id=item_id,
        step=step,
        render_fn=render_fn,
        summary_writer=summary_writer,
        save_dir=save_dir,
        datasource=datasource,
    )
    if jax.host_id() == 0:
      for k, v in stats.items():
        meters[k].update(v)

  if jax.host_id() == 0:
    for meter_name, meter in meters.items():
      summary_writer.scalar(
          tag=f"metrics-eval/{meter_name}/{tag}",
          value=meter.reduce("mean"),
          step=step,
      )
def _log_to_tensorboard(writer: tensorboard.SummaryWriter,
                        state: model_utils.TrainState,
                        scalar_params: training.ScalarParams,
                        stats: Dict[str, Union[Dict[str, jnp.ndarray],
                                               jnp.ndarray]],
                        time_dict: Dict[str, jnp.ndarray]):
  """Log statistics to Tensorboard."""
  step = int(state.optimizer.state.step)
  writer.scalar('params/learning_rate', scalar_params.learning_rate, step)
  writer.scalar('params/warp_alpha', state.warp_alpha, step)
  writer.scalar('params/time_alpha', state.time_alpha, step)
  writer.scalar('params/elastic_loss/weight',
                scalar_params.elastic_loss_weight, step)

  # pmean is applied in train_step so just take the item.
  for branch in {'coarse', 'fine'}:
    if branch not in stats:
      continue
    for stat_key, stat_value in stats[branch].items():
      writer.scalar(f'{stat_key}/{branch}', stat_value, step)

  if 'background_loss' in stats:
    writer.scalar('loss/background', stats['background_loss'], step)

  for k, v in time_dict.items():
    writer.scalar(f'time/{k}', v, step)
def _log_to_tensorboard(
    writer: tensorboard.SummaryWriter,
    state: model_utils.TrainState,
    scalar_params: training.ScalarParams,
    stats: Dict[str, Union[Dict[str, jnp.ndarray], jnp.ndarray]],
    time_dict: Dict[str, jnp.ndarray],
):
  """Log statistics to Tensorboard."""
  step = int(state.optimizer.state.step)
  writer.scalar("params/learning_rate", scalar_params.learning_rate, step)
  writer.scalar("params/warp_alpha", state.warp_alpha, step)
  writer.scalar("params/elastic_loss/weight",
                scalar_params.elastic_loss_weight, step)

  # pmean is applied in train_step so just take the item.
  for branch in {"coarse", "fine"}:
    if branch not in stats:
      continue
    for stat_key, stat_value in stats[branch].items():
      writer.scalar(f"{stat_key}/{branch}", stat_value, step)

  if "background_loss" in stats:
    writer.scalar("losses/background", stats["background_loss"], step)

  params = state.optimizer.target["model"]
  if "appearance_encoder" in params:
    embeddings = params["appearance_encoder"]["embed"]["embedding"]
    writer.histogram("appearance_embedding", embeddings, step)
  if "camera_encoder" in params:
    embeddings = params["camera_encoder"]["embed"]["embedding"]
    writer.histogram("camera_embedding", embeddings, step)
  if "warp_field" in params:
    embeddings = params["warp_field"]["metadata_encoder"]["embed"]["embedding"]
    writer.histogram("warp_embedding", embeddings, step)

  for k, v in time_dict.items():
    writer.scalar(f"time/{k}", v, step)
def _log_to_tensorboard(writer: tensorboard.SummaryWriter,
                        state: model_utils.TrainState,
                        scalar_params: training.ScalarParams,
                        stats: Dict[str, Union[Dict[str, jnp.ndarray],
                                               jnp.ndarray]],
                        time_dict: Dict[str, jnp.ndarray]):
  """Log statistics to Tensorboard."""
  step = int(state.optimizer.state.step)
  writer.scalar('params/learning_rate', scalar_params.learning_rate, step)
  writer.scalar('params/warp_alpha', state.warp_alpha, step)
  writer.scalar('params/elastic_loss/weight',
                scalar_params.elastic_loss_weight, step)

  # pmean is applied in train_step so just take the item.
  for branch in {'coarse', 'fine'}:
    if branch not in stats:
      continue
    for stat_key, stat_value in stats[branch].items():
      writer.scalar(f'{stat_key}/{branch}', stat_value, step)

  if 'background_loss' in stats:
    writer.scalar('losses/background', stats['background_loss'], step)

  params = state.optimizer.target['model']
  if 'appearance_encoder' in params:
    embeddings = params['appearance_encoder']['embed']['embedding']
    writer.histogram('appearance_embedding', embeddings, step)
  if 'camera_encoder' in params:
    embeddings = params['camera_encoder']['embed']['embedding']
    writer.histogram('camera_embedding', embeddings, step)
  if 'warp_field' in params:
    embeddings = params['warp_field']['metadata_encoder']['embed'][
        'embedding']
    writer.histogram('warp_embedding', embeddings, step)

  for k, v in time_dict.items():
    writer.scalar(f'time/{k}', v, step)
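# writer.histogram, used by the variants above, logs a tensor's value
# distribution alongside the scalar metrics. A minimal sketch with synthetic
# embeddings (the log directory and shapes are illustrative):
import numpy as np
from flax.metrics import tensorboard

writer = tensorboard.SummaryWriter('/tmp/embedding_demo')
embeddings = np.random.normal(size=(16, 8))  # e.g. 16 ids, 8-dim codes
writer.histogram('appearance_embedding', embeddings, step=0)
writer.flush()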