def test_named(self, wait_jax_async_dispatch, mock_time):
  mock_time.return_value = 0
  hook = periodic_actions.ReportProgress(
      every_steps=1, every_secs=None, num_train_steps=10)

  def _wait():
    # Here we depend on hook._executor=ThreadPoolExecutor(max_workers=1)
    hook._executor.submit(lambda: None).result()

  self.assertFalse(hook(1))  # Never triggers on first execution.
  with hook.timed("test1", wait_jax_async_dispatch):
    _wait()
    mock_time.return_value = 1
  _wait()
  with hook.timed("test2", wait_jax_async_dispatch):
    _wait()
    mock_time.return_value = 2
  _wait()
  with hook.timed("test1", wait_jax_async_dispatch):
    _wait()
    mock_time.return_value = 3
  _wait()
  mock_time.return_value = 4
  with self.assertLogs(level="INFO") as logs:
    self.assertTrue(hook(2))
  self.assertEqual(logs.output, [
      "INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA: 0m"
      " (0m : 50.0% test1, 25.0% test2)"
  ])
def test_named(self, wait_jax_async_dispatch, time_mock):
  time_mock.return_value = 0
  hook = periodic_actions.ReportProgress(
      every_steps=1, every_secs=None, num_train_steps=10)

  def _wait():
    # Here we depend on hook._executor=ThreadPoolExecutor(max_workers=1)
    hook._executor.submit(lambda: None).result()

  hook(1)
  with hook.timed("test1", wait_jax_async_dispatch):
    _wait()
    time_mock.return_value = 1
  _wait()
  with hook.timed("test2", wait_jax_async_dispatch):
    _wait()
    time_mock.return_value = 2
  _wait()
  with hook.timed("test1", wait_jax_async_dispatch):
    _wait()
    time_mock.return_value = 3
  _wait()
  time_mock.return_value = 4
  with self.assertLogs(level="INFO") as logs:
    hook(2)
  self.assertEqual(logs.output, [
      "INFO:absl:Setting work unit notes: 20.0% @2, 0.2 steps/s, ETA: 1 min"
      " (0 min : 50.0% test1, 25.0% test2)"
  ])
def test_called_every_step(self):
  hook = periodic_actions.ReportProgress(every_steps=3, num_train_steps=10)
  t = time.time()
  with self.assertRaisesRegex(
      ValueError, "PeriodicAction must be called after every step"):
    hook(1, t)
    hook(11, t)  # Raises exception.
def test_called_every_step(self):
  hook = periodic_actions.ReportProgress(every_steps=3, num_train_steps=10)
  t = time.time()
  with self.assertRaisesRegex(
      ValueError, "EveryNHook must be called after every step"):
    hook(1, t)
    # Skipping step 2.
    hook(11, t)
def test_without_num_train_steps(self):
  report = periodic_actions.ReportProgress(every_steps=2)
  t = time.time()
  with self.assertLogs(level="INFO") as logs:
    self.assertFalse(report(1, t))
    self.assertTrue(report(2, t + 0.12))
  # We did 1 step in 0.12s => 8.333 steps/s.
  self.assertEqual(logs.output,
                   ["INFO:absl:Setting work unit notes: 8.3 steps/s"])
def test_unknown_cardinality(self):
  report = periodic_actions.ReportProgress(
      every_steps=2, num_train_steps=tf.data.UNKNOWN_CARDINALITY)
  t = time.time()
  with self.assertLogs(level="INFO") as logs:
    self.assertFalse(report(1, t))
    self.assertTrue(report(2, t + 0.12))
  # We did 1 step in 0.12s => 8.333 steps/s.
  self.assertEqual(logs.output,
                   ["INFO:absl:Setting work unit notes: 8.3 steps/s"])
def test_every_steps(self):
  hook = periodic_actions.ReportProgress(
      every_steps=4, every_secs=None, num_train_steps=10)
  t = time.time()
  with self.assertLogs(level="INFO") as logs:
    hook(1, t)
    t += 0.11
    hook(2, t)
    t += 0.13
    hook(3, t)
    t += 0.12
    hook(4, t)
  # We did 1 step every 0.12s => 8.333 steps/s.
  self.assertEqual(logs.output, [
      "INFO:absl:Setting work unit notes: 40.0% @4, 8.3 steps/s, ETA: 0 min"
  ])
def test_every_secs(self):
  hook = periodic_actions.ReportProgress(
      every_steps=None, every_secs=0.3, num_train_steps=10)
  t = time.time()
  with self.assertLogs(level="INFO") as logs:
    self.assertFalse(hook(1, t))
    t += 0.11
    self.assertFalse(hook(2, t))
    t += 0.13
    self.assertFalse(hook(3, t))
    t += 0.12
    self.assertTrue(hook(4, t))
  # We did 1 step every 0.12s => 8.333 steps/s.
  self.assertEqual(logs.output, [
      "INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10), ETA: 0m"
  ])
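# The tests above pin down the ReportProgress contract: the hook must be
# called after every step, it emits a "steps/s, percent done, ETA" note at
# its configured cadence, and `timed()` blocks are attributed by name in that
# note. Below is a minimal sketch of how this API composes in a real loop; it
# is not taken from any of the sources in this file, and the workdir path,
# step counts, and sleep-based stand-ins are illustrative assumptions.
import time

from clu import metric_writers
from clu import periodic_actions

workdir = "/tmp/report_progress_demo"  # Hypothetical scratch directory.
num_train_steps = 100

writer = metric_writers.create_default_writer(workdir)
report_progress = periodic_actions.ReportProgress(
    num_train_steps=num_train_steps, writer=writer)

for step in range(1, num_train_steps + 1):
  time.sleep(0.01)  # Stand-in for the actual training step.
  report_progress(step)  # Must be called after every step (see tests above).
  if step % 20 == 0:
    with report_progress.timed("eval"):  # Named phase appears in the note.
      time.sleep(0.05)  # Stand-in for evaluation.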
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  if config.dataset.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")

  tf.io.gfile.makedirs(workdir)

  # Deterministic training.
  rng = jax.random.PRNGKey(config.seed)
  # Shift the numpy random seed by process_index() to shuffle data loaded by
  # different hosts.
  np.random.seed(20201473 + jax.process_index())

  #----------------------------------------------------------------------------
  # Build input pipeline.
  rng, data_rng = jax.random.split(rng)
  data_rng = jax.random.fold_in(data_rng, jax.process_index())

  config.dataset.data_dir = os.path.join(config.dataset.base_dir,
                                         config.dataset.scene)
  train_ds, eval_ds = datasets.create_dataset(config)
  example_batch = train_ds.peek()

  #----------------------------------------------------------------------------
  # Learning rate schedule.
  num_train_steps = config.train.max_steps
  if num_train_steps == -1:
    num_train_steps = train_ds.size()
  steps_per_epoch = num_train_steps // config.train.num_epochs
  logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
               steps_per_epoch)

  learning_rate_fn = train_utils.create_learning_rate_fn(config)

  #----------------------------------------------------------------------------
  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model, state = models.create_train_state(
      config,
      model_rng,
      learning_rate_fn=learning_rate_fn,
      example_batch=example_batch,
  )

  #----------------------------------------------------------------------------
  # Set up checkpointing of the model and the input pipeline.
  state = checkpoints.restore_checkpoint(workdir, state)
  initial_step = int(state.step) + 1

  #----------------------------------------------------------------------------
  # Distribute training.
  state = flax_utils.replicate(state)
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model=model,
          learning_rate_fn=learning_rate_fn,
          weight_decay=config.train.weight_decay,
          config=config,
      ),
      axis_name="batch",
  )

  # Get distributed rendering function.
  render_pfn = render_utils.get_render_function(
      model=model,
      config=config,
      randomized=False,  # No randomization for evaluation.
  )

  #----------------------------------------------------------------------------
  # Prepare metric writers.
  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0)
  if initial_step == 1:
    writer.write_hparams(dict(config))

  logging.info("Starting training loop at step %d.", initial_step)
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [report_progress]
  train_metrics = None

  # Prefetch_buffer_size = 6 x batch_size.
  ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)
  n_local_devices = jax.local_device_count()
  rng = rng + jax.process_index()  # Make random seed separate across hosts.
  keys = jax.random.split(rng, n_local_devices)  # For pmapping RNG keys.

  with metric_writers.ensure_flushes(writer):
    for step in range(initial_step, num_train_steps + 1):
      # `step` is a Python integer. `state.step` is a JAX integer on the
      # GPU/TPU devices.
      is_last_step = step == num_train_steps
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = next(ptrain_ds)
        state, metrics_update, keys = p_train_step(
            rng=keys, state=state, batch=batch)
        metric_update = flax_utils.unreplicate(metrics_update)
        train_metrics = (
            metric_update
            if train_metrics is None else train_metrics.merge(metric_update))

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      if step % config.train.log_loss_every_steps == 0 or is_last_step:
        writer.write_scalars(step, train_metrics.compute())
        train_metrics = None

      if step % config.train.render_every_steps == 0 or is_last_step:
        test_batch = next(eval_ds)
        test_pixels = model_utils.uint2float(
            test_batch.target_view.rgb)  # Extract for evaluation.
        with report_progress.timed("eval"):
          pred_color, pred_disp, pred_acc = eval_step(state, keys[0],
                                                      test_batch, render_pfn,
                                                      config)
        #------------------------------------------------------------------
        # Log metrics and images for host 0.
        #------------------------------------------------------------------
        if jax.process_index() == 0:
          psnr = model_utils.compute_psnr(
              ((pred_color - test_pixels)**2).mean())
          ssim = skmetrics.structural_similarity(
              pred_color.astype(np.float32),
              test_pixels.astype(np.float32),
              win_size=11,
              multichannel=True,
              gaussian_weights=True)
          writer.write_scalars(step, {
              "train_eval/test_psnr": psnr,
              "train_eval/test_ssim": ssim,
          })
          writer.write_images(
              step, {
                  "test_pred_color": pred_color[None, :],
                  "test_target": test_pixels[None, :]
              })
          if pred_disp is not None:
            writer.write_images(step, {"test_pred_disp": pred_disp[None, :]})
          if pred_acc is not None:
            writer.write_images(step, {"test_pred_acc": pred_acc[None, :]})
        #------------------------------------------------------------------

      if (jax.process_index() == 0) and (
          step % config.train.checkpoint_every_steps == 0 or is_last_step):
        # Write final metrics to file.
        with file_utils.open_file(
            os.path.join(workdir, "train_logs.json"), "w") as f:
          log_dict = metric_update.compute()
          for k, v in log_dict.items():
            log_dict[k] = v.item()
          f.write(json.dumps(log_dict))
        with report_progress.timed("checkpoint"):
          state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
          checkpoints.save_checkpoint(workdir, state_to_save, step, keep=100)

  logging.info("Finishing training at step %d", num_train_steps)
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.process_index() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      dataset_name=FLAGS.dataset_name,
      eval_dataset_name=FLAGS.eval_dataset_name,
      shard_idx=jax.process_index(),
      shard_count=jax.process_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      max_eval_length=FLAGS.max_eval_target_length,
      paracrawl_size=FLAGS.paracrawl_size,
      is_scores_path=FLAGS.is_scores_path,
      num_to_keep=FLAGS.data_selection_size,
      pseudo_path=FLAGS.pseudo_path,
      repeat_count=FLAGS.repeat_count,
      newscommentary_size=FLAGS.newscommentary_size,
      split_tokenizer=FLAGS.split_tokenizer)

  if FLAGS.aux_eval_dataset:
    aux_datasets = []
    aux_names = FLAGS.aux_eval_dataset.split(',')
    for name in aux_names:
      _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
          dataset_name=name,
          eval_dataset_name=None,
          shard_idx=jax.process_index(),
          shard_count=jax.process_count(),
          data_dir=FLAGS.data_dir,
          vocab_path=vocab_path,
          target_vocab_size=FLAGS.vocab_size,
          batch_size=FLAGS.batch_size,
          max_length=FLAGS.max_target_length,
          max_eval_length=FLAGS.max_eval_target_length,
          paracrawl_size=FLAGS.paracrawl_size,
          is_scores_path=FLAGS.is_scores_path,
          num_to_keep=FLAGS.data_selection_size,
          pseudo_path=FLAGS.pseudo_path,
          repeat_count=FLAGS.repeat_count,
          newscommentary_size=FLAGS.newscommentary_size)
      aux_datasets.append(aux_eval_ds)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # It's possible that this is supposed to be the per-device batch size.
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # Apply an optimizer to this tree.
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that. Else, just continue where we
    # left off.
    model_path = (FLAGS.pretrained_model_dir
                  if FLAGS.pretrained_model_dir else FLAGS.model_dir)
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      FLAGS.model_dir, just_logging=jax.process_index() > 0)

  flag_key = [k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k]
  if flag_key:
    flag_key = flag_key[0]
    local_flags = {
        f.name: f.value for f in FLAGS.flags_by_module_dict()[flag_key]
    }
    writer.write_hparams(local_flags)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = common.create_learning_rate_scheduler(
      base_learning_rate=FLAGS.learning_rate,
      warmup_steps=FLAGS.warmup_steps,
      steps_per_cycle=FLAGS.steps_per_cycle,
      init_step=start_step,
      finetune_lr=FLAGS.finetune_lr)

  # Compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_util.train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing),
      axis_name='batch',
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(train_util.eval_step, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          train_util.initialize_cache,
          max_decode_len=FLAGS.max_predict_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          train_util.predict_step,
          config=predict_config,
          beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=FLAGS.num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=FLAGS.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  total_steps = start_step + FLAGS.num_train_steps
  if FLAGS.eval_only:
    total_steps = start_step + 1
  best_eval_loss = 1000
  curr_eval_loss = 1000
  eval_loss_history = []
  last_eval_step = 0
  do_resample_data = False
  gradual_selection_size = FLAGS.data_selection_size
  dynamic_eval_freq = FLAGS.eval_frequency

  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, total_steps):
      is_last_step = step == total_steps - 1

      # Resample training data for gradual FT.
      if do_resample_data:
        do_resample_data = False
        gradual_selection_size *= .7
        dynamic_eval_freq = int(gradual_selection_size / 1000 / 4)
        train_ds, eval_ds, predict_ds, encoder = (
            input_pipeline.get_wmt_datasets(
                dataset_name=FLAGS.dataset_name,
                eval_dataset_name=FLAGS.eval_dataset_name,
                shard_idx=jax.process_index(),
                shard_count=jax.process_count(),
                data_dir=FLAGS.data_dir,
                vocab_path=vocab_path,
                target_vocab_size=FLAGS.vocab_size,
                batch_size=FLAGS.batch_size,
                max_length=FLAGS.max_target_length,
                max_eval_length=FLAGS.max_eval_target_length,
                paracrawl_size=FLAGS.paracrawl_size,
                is_scores_path=FLAGS.is_scores_path,
                num_to_keep=int(gradual_selection_size),
                pseudo_path=FLAGS.pseudo_path,
                repeat_count=FLAGS.repeat_count,
                newscommentary_size=FLAGS.newscommentary_size,
                split_tokenizer=FLAGS.split_tokenizer))
        train_iter = iter(train_ds)

      # Shard data to devices and do a training step.
      if not FLAGS.eval_only:
        logging.info('Doing Training.')
        with jax.profiler.StepTraceAnnotation('train', step_num=step):
          try:
            batch = common_utils.shard(
                jax.tree_map(np.asarray, next(train_iter)))
            optimizer, metrics = p_train_step(
                optimizer, batch, dropout_rng=dropout_rngs)
            train_metrics.append(metrics)
          except StopIteration:
            is_last_step = True

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if (step - start_step) % dynamic_eval_freq == 0 or is_last_step:
        if not FLAGS.eval_only:
          with report_progress.timed('training_metrics'):
            logging.info('Gathering training metrics.')
            train_metrics = common_utils.get_metrics(train_metrics)
            lr = train_metrics.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, train_metrics)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            summary = {'train_' + k: v for k, v in summary.items()}
            writer.write_scalars(step, summary)
            train_metrics = []

        if FLAGS.eval_only:
          p_eval_per_pos_step = jax.pmap(
              functools.partial(
                  train_util.eval_per_pos_step, config=eval_config),
              axis_name='batch')
          # Get per example loss.
          loss_filename = FLAGS.model_dir + '/test_losses.csv'
          train_util.write_per_example_losses(
              p_eval_step=p_eval_per_pos_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=FLAGS.num_eval_steps,
              loss_filename=loss_filename)
        else:
          with report_progress.timed('eval'):
            eval_results = train_util.evaluate(
                p_eval_step=p_eval_step,
                target=optimizer.target,
                eval_ds=eval_ds,
                num_eval_steps=FLAGS.num_eval_steps)
            curr_eval_loss = eval_results['loss']
            eval_loss_history.append(curr_eval_loss)
            if len(eval_loss_history) > 1:
              improvement_rate = 0.000004
              orig_loss = eval_loss_history[-2]
              true_improvement = orig_loss - curr_eval_loss
              expected_improvement = (step - last_eval_step) * improvement_rate
              # percent_change = (orig_loss - curr_eval_loss) / orig_loss
              # percent_change *= 100
              if true_improvement < expected_improvement:  # percent_change<.1:
                do_resample_data = True
            last_eval_step = step
            writer.write_scalars(
                step, {'eval_' + k: v for k, v in eval_results.items()})

        if FLAGS.aux_eval_dataset:
          for aux_i, aux_eval_ds in enumerate(aux_datasets):
            with report_progress.timed('aux_eval'):
              eval_results = train_util.evaluate(
                  p_eval_step=p_eval_step,
                  target=optimizer.target,
                  eval_ds=aux_eval_ds,
                  num_eval_steps=FLAGS.num_eval_steps)
              writer.write_scalars(
                  step, {
                      'aux' + str(aux_i) + '_eval_' + k: v
                      for k, v in eval_results.items()
                  })

        if FLAGS.compute_bleu:
          with report_progress.timed('translate_and_bleu'):
            decode_file = FLAGS.model_dir + '/decodes.csv'
            exemplars, bleu_score = train_util.translate_and_calculate_bleu(
                p_pred_step=p_pred_step,
                p_init_cache=p_init_cache,
                target=optimizer.target,
                predict_ds=predict_ds,
                decode_tokens=decode_tokens,
                max_predict_length=FLAGS.max_predict_length,
                num_eval_steps=FLAGS.num_eval_steps,
                decode_file=decode_file if FLAGS.eval_only else '')
            writer.write_scalars(step, {'bleu': bleu_score})
            writer.write_texts(step, {'samples': exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0 or
                         is_last_step)
      if (FLAGS.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        if curr_eval_loss < best_eval_loss:  # Only save better checkpoints.
          best_eval_loss = curr_eval_loss
          with report_progress.timed('checkpoint'):
            checkpoints.save_checkpoint(
                FLAGS.model_dir,
                jax_utils.unreplicate(optimizer),
                step,
                keep=FLAGS.chkpts_to_keep,
                overwrite=True)

      if is_last_step:
        break
def train_and_evaluate(self, workdir):
  """Runs a training and evaluation loop.

  Args:
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  tf.io.gfile.makedirs(workdir)
  config = self.config
  substeps = config.training.substeps

  # Learning rate schedule.
  num_train_steps = config.training.num_train_steps
  logging.info('num_train_steps=%d', num_train_steps)

  # Get train state.
  state = self._train_state

  # Set up checkpointing of the model and the input pipeline.
  checkpoint_dir = os.path.join(workdir, 'checkpoints')
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=5)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step)

  # Distribute training.
  state = flax_utils.replicate(state)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0)
  if initial_step == 0:
    writer.write_hparams(dict(config))

  logging.info('Starting training loop at step %d.', initial_step)
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
    ]
  step = initial_step
  with metric_writers.ensure_flushes(writer):
    while step < num_train_steps:
      # `step` is a Python integer. `state.step` is a JAX integer on the
      # GPU/TPU devices.
      is_last_step = step + substeps >= num_train_steps

      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        inputs = jax.tree_map(np.asarray, next(self._train_iter))
        state, outputs = self._update_func(state, inputs)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      new_step = int(state.step[0])
      assert new_step == step + substeps
      step = new_step

      is_eval = step % config.logs.eval_full_every_steps == 0 or is_last_step
      if step % config.logs.log_loss_every_steps == 0 and not is_eval:

        def avg_over_substeps(x):
          assert x.shape[0] == substeps
          return float(x.mean(axis=0))

        # Extract scalars and images.
        outputs = flax_utils.unreplicate(outputs)
        outputs = jax.tree_map(avg_over_substeps, outputs)
        scalars = outputs['scalars']
        writer.write_scalars(step, scalars)

      if is_eval:
        with report_progress.timed('eval_full'):
          outputs = self._eval_epoch(params=state.ema_params)
          outputs = flax_utils.unreplicate(outputs)
          scalars = outputs['scalars']
          writer.write_scalars(step, scalars)

      if step % config.logs.checkpoint_every_steps == 0 or is_last_step:
        with report_progress.timed('checkpoint'):
          ckpt.save(flax_utils.unreplicate(state))

  logging.info('Finishing training at step %d', num_train_steps)
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
    config.vocab_path = vocab_path
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=jax.local_device_count(),
      config=config,
      reverse_translation=config.reverse_translation,
      vocab_path=vocab_path)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  if config.num_predict_steps > 0:
    predict_ds = predict_ds.take(config.num_predict_steps)

  logging.info("Initializing model, optimizer, and step functions.")

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=config.share_embeddings,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  input_shape = (config.per_device_batch_size, config.max_target_length)
  target_shape = (config.per_device_batch_size, config.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # Apply an optimizer to this tree.
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  if start_step == 0:
    writer.write_hparams(dict(config))

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate,
      warmup_steps=config.warmup_steps)

  # Compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=config.label_smoothing),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(eval_step, config=eval_config), axis_name="batch")
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=config.max_predict_length,
          config=predict_config),
      axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=config.beam_size),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info("Starting training loop.")
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        optimizer, metrics = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed("training_metrics"):
          logging.info("Gathering training metrics.")
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop("learning_rate").mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop("denominator")
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary["learning_rate"] = lr
          summary = {"train_" + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed("eval"):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=config.num_eval_steps)
          writer.write_scalars(
              step, {"eval_" + k: v for k, v in eval_results.items()})

        with report_progress.timed("translate_and_bleu"):
          exemplars, bleu_score = translate_and_calculate_bleu(
              p_pred_step=p_pred_step,
              p_init_cache=p_init_cache,
              target=optimizer.target,
              predict_ds=predict_ds,
              decode_tokens=decode_tokens,
              max_predict_length=config.max_predict_length)
          writer.write_scalars(step, {"bleu": bleu_score})
          writer.write_texts(step, {"samples": exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (step % config.checkpoint_every_steps == 0 or
                         is_last_step)
      if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        with report_progress.timed("checkpoint"):
          checkpoints.save_checkpoint(workdir,
                                      jax_utils.unreplicate(optimizer), step)
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  logging.info('Starting training at %s', workdir)
  tf.io.gfile.makedirs(workdir)
  if jax.process_index() == 0:
    with tf.io.gfile.GFile(os.path.join(workdir, 'config.json'), 'w') as f:
      json.dump(config.to_dict(), f, indent=2)
  rng = jax.random.PRNGKey(config.seed)

  # Build input pipeline.
  rng, data_rng = jax.random.split(rng)
  data_rng = jax.random.fold_in(data_rng, jax.process_index())
  train_ds, eval_ds = input_pipeline.create_datasets(config.dataset, data_rng)
  train_iter = iter(train_ds)

  test_ds = []
  for split in config.dataset.test_splits:
    ds = input_pipeline.create_val_dataset(
        config.dataset, split, config.dataset.test_per_device_batch_size,
        config.dataset.test_pad_last_batch)
    test_ds.append(ds)

  # Learning rate schedule.
  num_train_steps = config.num_train_steps
  if num_train_steps == -1:
    num_train_steps = train_ds.cardinality().numpy()
  steps_per_epoch = num_train_steps // config.dataset.num_epochs
  logging.info('num_train_steps=%d, steps_per_epoch=%d', num_train_steps,
               steps_per_epoch)
  learning_rate_fn = functools.partial(
      train_utils.get_learning_rate,
      base_learning_rate=config.learning_rate,
      num_train_steps=num_train_steps,
      schedule_type=config.learning_rate_schedule,
      warmup_proportion=config.warmup_proportion,
      step_boundaries=config.learning_rate_step_boundaries)

  # Initialize model.
  inputs = train_utils.get_init_inputs(train_ds)
  rng, model_rng = jax.random.split(rng)
  eval_config = models.TransformerConfig(**config.model.to_dict())
  train_config = eval_config.replace(deterministic=False)
  model = models.Model(eval_config)
  state = train_utils.create_train_state(
      model, config, model_rng, inputs=inputs)

  # Set up checkpointing of the model and the input pipeline.
  checkpoint_dir = os.path.join(workdir, 'checkpoints')
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=3)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step) + 1

  # Distribute training.
  state = flax_utils.replicate(state)
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          grad_clip=config.grad_clip),
      axis_name='batch',
      donate_argnums=(0,))
  p_eval_step = jax.pmap(
      functools.partial(eval_step, config=eval_config), axis_name='batch')

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0)
  if initial_step == 1:
    writer.write_hparams(train_utils.flatten_config(config))

  logging.info('Starting training loop at step %d.', initial_step)
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(
            num_profile_steps=config.num_profile_steps, logdir=workdir)
    ]

  rng, train_rngs = jax.random.split(rng)
  train_rngs = jax.random.fold_in(train_rngs, jax.process_index())
  train_rngs = jax.random.split(train_rngs, jax.local_device_count())

  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(initial_step, num_train_steps + 1):
      is_last_step = step == num_train_steps

      with jax.profiler.StepTraceContext('train', step_num=step):
        batch = jax.tree_map(np.asarray, next(train_iter))
        state, metrics = p_train_step(batch=batch, rng=train_rngs, state=state)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      if config.log_loss_every_steps > 0 and (
          step % config.log_loss_every_steps == 0 or is_last_step):
        train_metrics = common_utils.get_metrics(train_metrics)
        lr = train_metrics.pop('learning_rate').mean()
        train_summary = train_utils.metrics_summary(train_metrics, 'train')
        train_summary['learning_rate'] = lr
        writer.write_scalars(step, train_summary)
        train_metrics = []

      if config.eval_every_steps > 0 and (
          step % config.eval_every_steps == 0 or is_last_step):
        with report_progress.timed('eval'):
          eval_summary = evaluate(
              p_eval_step, state, eval_ds, config.num_eval_steps)
          writer.write_scalars(step, eval_summary)

      if config.checkpoint_every_steps > 0 and (
          step % config.checkpoint_every_steps == 0 or is_last_step):
        with report_progress.timed('checkpoint'):
          ckpt.save(flax_utils.unreplicate(state))
        logging.info('Checkpoint saved to %s', checkpoint_dir)

  logging.info('Finishing training at step %d', num_train_steps)
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs training interleaved with evaluation."""
  # Setup input pipeline.
  dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train')
  ds_train, ds_test = input_pipeline.get_datasets(config)
  batch = next(iter(ds_train))
  logging.info(ds_train)
  logging.info(ds_test)

  # Build VisionTransformer architecture.
  model_cls = {
      'ViT': models.VisionTransformer,
      'Mixer': models.MlpMixer
  }[config.get('model_type', 'ViT')]
  model = model_cls(num_classes=dataset_info['num_classes'], **config.model)

  def init_model():
    return model.init(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension for initialization.
        jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name),
        train=False)

  # Use JIT to make sure params reside in CPU memory.
  variables = jax.jit(init_model, backend='cpu')()

  model_or_filename = config.get('model_or_filename')
  if model_or_filename:
    # Loading model from repo published with "How to train your ViT? Data,
    # Augmentation, and Regularization in Vision Transformers" paper.
    # https://arxiv.org/abs/2106.10270
    if '-' in model_or_filename:
      filename = model_or_filename
    else:
      # Select best checkpoint from i21k pretraining by final upstream
      # validation accuracy.
      df = checkpoint.get_augreg_df(directory=config.pretrained_dir)
      sel = df.filename.apply(
          lambda filename: filename.split('-')[0] == model_or_filename)
      best = df.loc[sel].query('ds=="i21k"').sort_values('final_val').iloc[-1]
      filename = best.filename
      logging.info('Selected filename="%s" for "%s" with final_val=%.3f',
                   filename, model_or_filename, best.final_val)
    pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz')
  else:
    # ViT / Mixer papers.
    filename = config.model.name
    pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz')
  if not tf.io.gfile.exists(pretrained_path):
    raise ValueError(
        f'Could not find "{pretrained_path}" - you can download models from '
        '"gs://vit_models/imagenet21k" or directly set '
        '--config.pretrained_dir="gs://vit_models/imagenet21k".')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=variables['params'],
      model_config=config.model)

  total_steps = config.total_steps
  lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr,
                                              config.decay_type,
                                              config.warmup_steps)

  update_fn_repl = make_update_fn(
      apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn)
  infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False))

  # Create optimizer and replicate it over all TPUs/GPUs.
  opt = momentum_clip.Optimizer(
      dtype=config.optim_dtype,
      grad_norm_clip=config.grad_norm_clip).create(params)

  initial_step = 1
  opt, initial_step = flax_checkpoints.restore_checkpoint(
      workdir, (opt, initial_step))
  logging.info('Will start/continue training at initial_step=%d', initial_step)

  opt_repl = flax.jax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore.
  del opt
  del params

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

  # Setup metric writer & hooks.
  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())
  hooks = [
      periodic_actions.Profile(logdir=workdir),
      periodic_actions.ReportProgress(
          num_train_steps=total_steps, writer=writer),
  ]

  # Run training loop.
  logging.info('Starting training loop; initial compile can take a while...')
  t0 = lt0 = time.time()
  lstep = initial_step
  for step, batch in zip(
      range(initial_step, total_steps + 1),
      input_pipeline.prefetch(ds_train, config.prefetch)):

    with jax.profiler.StepTraceContext('train', step_num=step):
      opt_repl, loss_repl, update_rng_repl = update_fn_repl(
          opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)

    for hook in hooks:
      hook(step)

    if step == initial_step:
      logging.info('First step took %.1f seconds.', time.time() - t0)
      t0 = time.time()
      lt0, lstep = time.time(), step

    # Report training metrics.
    if config.progress_every and step % config.progress_every == 0:
      img_sec_core_train = (config.batch * (step - lstep) /
                            (time.time() - lt0)) / jax.device_count()
      lt0, lstep = time.time(), step
      writer.write_scalars(
          step,
          dict(
              train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
              img_sec_core_train=img_sec_core_train))
      done = step / total_steps
      logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '  # pylint: disable=logging-format-interpolation
                   f'img/sec/core: {img_sec_core_train:.1f}, '
                   f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')

    # Run evaluation.
    if ((config.eval_every and step % config.eval_every == 0) or
        (step == total_steps)):
      accuracies = []
      lt0 = time.time()
      for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
        logits = infer_fn_repl(
            dict(params=opt_repl.target), test_batch['image'])
        accuracies.append(
            (np.argmax(logits, axis=-1) == np.argmax(
                test_batch['label'], axis=-1)).mean())
      accuracy_test = np.mean(accuracies)
      img_sec_core_test = (
          config.batch_eval * ds_test.cardinality().numpy() /
          (time.time() - lt0) / jax.device_count())
      lt0 = time.time()

      lr = float(lr_fn(step))
      logging.info(f'Step: {step} '  # pylint: disable=logging-format-interpolation
                   f'Learning rate: {lr:.7f}, '
                   f'Test accuracy: {accuracy_test:0.5f}, '
                   f'img/sec/core: {img_sec_core_test:.1f}')
      writer.write_scalars(
          step,
          dict(
              accuracy_test=accuracy_test,
              lr=lr,
              img_sec_core_test=img_sec_core_test))

    # Store checkpoint.
    if ((config.checkpoint_every and step % config.checkpoint_every == 0) or
        step == total_steps):
      checkpoint_path = flax_checkpoints.save_checkpoint(
          workdir, (flax.jax_utils.unreplicate(opt_repl), step), step)
      logging.info('Stored checkpoint at step %d to "%s"', step,
                   checkpoint_path)

  return flax.jax_utils.unreplicate(opt_repl)
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if config.batch_size % n_devices:
    raise ValueError("Batch size must be divisible by the number of devices")

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=n_devices,
      dataset_name=config.dataset_name,
      eval_dataset_name=config.eval_dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      vocab_path=vocab_path,
      target_vocab_size=config.vocab_size,
      batch_size=config.batch_size,
      max_corpus_chars=config.max_corpus_chars,
      max_length=config.max_target_length,
      max_eval_length=config.max_eval_target_length)
  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  if config.num_predict_steps > 0:
    predict_ds = predict_ds.take(config.num_predict_steps)

  logging.info("Initializing model, optimizer, and step functions.")

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=config.share_embeddings,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = random.PRNGKey(config.seed)
  rng, init_rng = random.split(rng)
  input_shape = (config.batch_size, config.max_target_length)
  target_shape = (config.batch_size, config.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # Apply an optimizer to this tree.
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  if start_step == 0:
    writer.write_hparams(dict(config))

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate,
      warmup_steps=config.warmup_steps)

  # Compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=config.label_smoothing),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          config=eval_config,
          label_smoothing=config.label_smoothing),
      axis_name="batch")
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=config.max_predict_length,
          config=predict_config),
      axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=config.beam_size),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, n_devices)

  logging.info("Starting training loop.")
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5)
    ]
  metrics_all = []
  with metric_writers.ensure_flushes(writer):
    for step, batch in zip(range(start_step, config.num_train_steps),
                           train_iter):
      # Shard data to devices and do a training step.
      batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
      optimizer, metrics, dropout_rngs = p_train_step(
          optimizer, batch, dropout_rng=dropout_rngs)
      metrics_all.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      # Save a checkpoint on one host after every checkpoint_freq steps.
      if (config.save_checkpoints and step % config.checkpoint_freq == 0 and
          step > 0 and jax.host_id() == 0):
        checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                    step)

      # Periodic metric handling.
      if step % config.eval_frequency != 0 and step > 0:
        continue

      # Training Metrics
      logging.info("Gathering training metrics.")
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop("learning_rate").mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop("denominator")
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary["learning_rate"] = lr
      summary = {"train_" + k: v for k, v in summary.items()}
      writer.write_scalars(step, summary)
      metrics_all = []

      # Eval Metrics
      logging.info("Gathering evaluation metrics.")
      eval_metrics = []
      eval_iter = iter(eval_ds)
      for _, eval_batch in zip(range(config.num_eval_steps), eval_iter):
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        eval_batch = common_utils.shard(eval_batch)
        metrics = p_eval_step(optimizer.target, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop("denominator")
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)
      eval_summary = {"eval_" + k: v for k, v in eval_summary.items()}
      writer.write_scalars(step, eval_summary)

      # Translation and BLEU Score.
      logging.info("Translating evaluation dataset.")
      t_inference_start = time.time()
      sources, references, predictions = [], [], []
      for pred_batch in predict_ds:
        pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = pred_batch["inputs"].shape[0]
        if cur_pred_batch_size % n_devices:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          pred_batch = jax.tree_map(
              lambda x: pad_examples(x, padded_size),  # pylint: disable=cell-var-from-loop
              pred_batch)
        pred_batch = common_utils.shard(pred_batch)
        cache = p_init_cache(pred_batch["inputs"])
        predicted = p_pred_step(pred_batch["inputs"], optimizer.target, cache,
                                eos_id, config.max_predict_length)
        predicted = tohost(predicted)
        inputs = tohost(pred_batch["inputs"])
        targets = tohost(pred_batch["targets"])
        # Iterate through non-padding examples of batch.
        for i, s in enumerate(predicted[:cur_pred_batch_size]):
          sources.append(decode_tokens(inputs[i]))
          references.append(decode_tokens(targets[i]))
          predictions.append(decode_tokens(s))
      logging.info("Translation: %d predictions %d references %d sources.",
                   len(predictions), len(references), len(sources))
      logging.info("Translation time: %.4f s step %d.",
                   time.time() - t_inference_start, step)

      # Calculate BLEU score for translated eval corpus against reference.
      bleu_matches = bleu.bleu_partial(references, predictions)
      all_bleu_matches = per_host_sum_pmap(bleu_matches)
      bleu_score = bleu.complete_bleu(*all_bleu_matches)

      # Save translation samples for tensorboard.
      exemplars = ""
      for n in np.random.choice(np.arange(len(predictions)), 8):
        exemplars += f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n"
      writer.write_scalars(step, {"bleu": bleu_score})
      writer.write_texts(step, {"samples": exemplars})
def train_and_evaluate(config, workdir):
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The train state (which includes the `.params`).
  """
  # Seed for reproducibility.
  rng = jax.random.PRNGKey(config.rng_seed)

  # Set up logging.
  summary_writer = metric_writers.create_default_writer(workdir)
  summary_writer.write_hparams(dict(config))

  # Get datasets.
  rng, dataset_rng = jax.random.split(rng)
  dataset = input_pipeline.get_dataset(config, dataset_rng)
  graph, labels, masks = jax.tree_map(jnp.asarray, dataset)
  labels = jax.nn.one_hot(labels, config.num_classes)
  train_mask = masks['train']
  train_indices = jnp.where(train_mask)[0]
  train_labels = labels[train_indices]
  num_training_nodes = len(train_indices)

  # Get subgraphs.
  if config.differentially_private_training:
    graph = jax.tree_map(np.asarray, graph)
    subgraphs = get_subgraphs(graph, pad_to=config.pad_subgraphs_to)
    graph = jax.tree_map(jnp.asarray, graph)

    # We only need the subgraphs for training nodes.
    train_subgraphs = subgraphs[train_indices]
    del subgraphs
  else:
    train_subgraphs = None

  # Initialize privacy accountant.
  training_privacy_accountant = privacy_accountants.get_training_privacy_accountant(
      config, num_training_nodes, compute_max_terms_per_node(config))

  # Construct and initialize model.
  rng, init_rng = jax.random.split(rng)
  estimation_indices = get_estimation_indices(train_indices, config)
  state = create_train_state(init_rng, config, graph, train_labels,
                             train_subgraphs, estimation_indices)

  # Set up checkpointing of the model.
  checkpoint_dir = os.path.join(workdir, 'checkpoints')
  ckpt = checkpoint.Checkpoint(checkpoint_dir, max_to_keep=2)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step) + 1

  # Log overview of parameters.
  parameter_overview.log_parameter_overview(state.params)

  # Log metrics after initialization.
  logits = compute_logits(state, graph)
  metrics_after_init = compute_metrics(logits, labels, masks)
  metrics_after_init['epsilon'] = 0
  log_metrics(0, metrics_after_init, summary_writer, postfix='_after_init')

  # Train model.
  rng, train_rng = jax.random.split(rng)
  max_training_epsilon = get_max_training_epsilon(config)

  # Hooks called periodically during training.
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_training_steps, writer=summary_writer)
  profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
  hooks = [report_progress, profiler]

  for step in range(initial_step, config.num_training_steps):
    # Perform one step of training.
    with jax.profiler.StepTraceAnnotation('train', step_num=step):
      # Sample batch.
      step_rng = jax.random.fold_in(train_rng, step)
      indices = jax.random.choice(step_rng, num_training_nodes,
                                  (config.batch_size,))

      # Compute gradients.
      if config.differentially_private_training:
        grads = compute_updates_for_dp(state, graph, train_labels,
                                       train_subgraphs, indices,
                                       config.adjacency_normalization)
      else:
        grads = compute_updates(state, graph, train_labels, indices)

      # Update parameters.
      state = update_model(state, grads)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, 'Finished training step %d.', 10, step)
    for hook in hooks:
      hook(step)

    # Evaluate, if required.
    is_last_step = (step == config.num_training_steps - 1)
    if step % config.evaluate_every_steps == 0 or is_last_step:
      with report_progress.timed('eval'):
        # Check if privacy budget exhausted.
        training_epsilon = training_privacy_accountant(step + 1)
        if (max_training_epsilon is not None and
            training_epsilon >= max_training_epsilon):
          break

        # Compute metrics.
        logits = compute_logits(state, graph)
        metrics_during_training = compute_metrics(logits, labels, masks)
        metrics_during_training['epsilon'] = training_epsilon
        log_metrics(step, metrics_during_training, summary_writer)

    # Checkpoint, if required.
    if step % config.checkpoint_every_steps == 0 or is_last_step:
      with report_progress.timed('checkpoint'):
        ckpt.save(state)

  return state
def evaluate(base_dir, config, *, train_state):
  """Eval function."""
  chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'eval'))
  writer = create_default_writer()

  key = jax.random.PRNGKey(config.eval.seed)
  model_init_key, ds_key = jax.random.split(key)

  linear_module = LinearModule(config.eval.num_tasks)
  params = linear_module.init(model_init_key,
                              jnp.zeros((config.encoder.embedding_dim,)))
  lr = optax.cosine_decay_schedule(config.eval.learning_rate,
                                   config.num_eval_steps)
  optim = optax.adam(lr)

  ds = dataset.get_dataset(config, ds_key, num_tasks=config.eval.num_tasks)
  ds_iter = iter(ds)

  state = TrainState.create(
      apply_fn=linear_module.apply, params=params, tx=optim)
  state = chkpt_manager.restore_or_initialize(state)

  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_eval_steps, writer=writer)
  hooks = [
      report_progress,
      periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
  ]

  def handle_preemption(signal_number, _):
    logging.info('Received signal %d, saving checkpoint.', signal_number)
    with report_progress.timed('checkpointing'):
      chkpt_manager.save(state)
    logging.info('Finished saving checkpoint.')

  signal.signal(signal.SIGTERM, handle_preemption)

  metrics = EvalMetrics.empty()
  with metric_writers.ensure_flushes(writer):
    for step in tqdm.tqdm(range(state.step, config.num_eval_steps)):
      with jax.profiler.StepTraceAnnotation('eval', step_num=step):
        states, targets = next(ds_iter)
        state, metrics = evaluate_step(
            train_state, state, metrics, states, targets)

      if step % config.log_metrics_every == 0:
        writer.write_scalars(step, metrics.compute())
        metrics = EvalMetrics.empty()

      for hook in hooks:
        hook(step)

    # Finally, evaluate on the true(ish) test aux task matrix.
    states, targets = dataset.EvalDataset(config, ds_key).get_batch()

    @jax.jit
    def loss_fn():
      outputs = train_state.apply_fn(train_state.params, states)
      phis = outputs.phi
      predictions = jax.vmap(
          state.apply_fn, in_axes=(None, 0))(state.params, phis)
      return jnp.mean(optax.l2_loss(predictions, targets))

    test_loss = loss_fn()
    writer.write_scalars(config.num_eval_steps + 1, {'test_loss': test_loss})
def training_loop( *, module, rng, train_ds, eval_ds, loss_fn, optimizer, train_metrics_dict, eval_metrics_dict, stats_aggregators, config, workdir, ): """Runs a training and evaluation loop. Args: module: The module that should be trained. rng: A jax pseudo-random number generator key. train_ds: Dataset used for training. eval_ds: Dataset used for evaluation. loss_fn: Loss function to use for training. optimizer: Optax optimizer to use for training. train_metrics_dict: Collection of metrics to be collected during training. eval_metrics_dict: Collection of metrics to be collected during evaluation. stats_aggregators: Dictionary of statistics aggregator functions to be run on the first evaluation batch. These functions ingest the stats returned by the model and output a Dict[str, image/scalar] that will be logged. config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. Raises: RuntimeError: If a training metric is NaN or inf. Returns: Training state. """ rng, model_rng = jax.random.split(rng) input_shape = tuple(train_ds.element_spec["image"].shape[1:]) model, init_params, init_state = create_model(module, input_shape, model_rng) parameter_overview.log_parameter_overview(model.params) # Load a pretrained model parameters and state. Ignore the step and the # optimizer state in the checkpoint. pretrained_path = config.get("pretrained_checkpoint", "") if pretrained_path: logging.info("Load pretrained weights from '%s'", pretrained_path) state_dict = checkpoint.load_state_dict(pretrained_path) flatten_model_params = utils.flatten_dict(state_dict["model_params"], sep="/") model_state = state_dict["model_state"] # A prefix can be used to replace only a subpart of the network (e.g the # encoder). Prepend the prefix (if any) to model parameters and states. prefix = config.get("pretrained_prefix", "") if prefix: flatten_model_params = utils.add_prefix_to_dict_keys( flatten_model_params, f"{prefix}/") model_state = utils.add_prefix_to_dict_keys( model_state, f"/{prefix}") # Merge the params/state from the checkpoint into the initial params/state. flatten_init_params = utils.flatten_dict(init_params, sep="/") flatten_init_params, ignored_params = utils.override_dict( flatten_init_params, flatten_model_params) init_params = utils.unflatten_dict(flatten_init_params, delimiter="/") init_state, _ = utils.override_dict(init_state, model_state) if ignored_params: logging.warning( "%d/%d parameters from the pretrained checkpoint " "were ignored: %s", len(ignored_params), len(flatten_init_params), ignored_params) optimizer_state = optimizer.init(init_params) state = TrainState(step=1, model_params=init_params, model_state=init_state, optimizer_state=optimizer_state) # type: ignore # Do not keep a copy of the initial model. del init_params, init_state, optimizer_state train_iter = iter(train_ds) # pytype: disable=wrong-arg-types checkpoint_dir = os.path.join(workdir, "checkpoints") ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir) state = ckpt.restore_or_initialize(state) initial_step = int(state.step) # Replicate our parameters. state = flax.jax_utils.replicate(state) writer = metric_writers.create_default_writer( workdir, just_logging=jax.host_id() > 0) step_timer = utils.StepTimer(batch_size=config.batch_size, initial_step=initial_step) # Write config to the summary files. This makes the hyperparameters available # in TensorBoard and makes comparison of runs with tensorboard/ easier. 
  if initial_step == 1:
    writer.write_hparams(utils.flatten_dict(config.to_dict()))

  # Generate per-device PRNG keys for the training loop.
  rng, train_rng = jax.random.split(rng)
  train_rngs = jax.random.split(train_rng, jax.local_device_count())

  # Generate per-device PRNG keys for model evaluation.
  rng, eval_rng = jax.random.split(rng)
  eval_rngs = jax.random.split(eval_rng, jax.local_device_count())

  logging.info("Starting training loop at step %d.", initial_step)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  train_metrics = utils.Means()

  do_eval_only = config.get("do_eval_only", False)
  if do_eval_only:
    config.num_train_steps = 1

  debug_enabled = config.get("debug", False)
  previous_grads = grads = None
  previous_updates = updates = None
  previous_state = None

  for step in range(initial_step, config.num_train_steps + 1):
    is_last_step = step == config.num_train_steps
    if debug_enabled:
      previous_grads = grads
      previous_updates = updates
      previous_state = state

    # Skip the training step when only doing eval.
    if not do_eval_only:
      # Use ._numpy() to avoid copy.
      batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
      state, grads, updates, metrics, training_stats, train_rngs = train_step(
          state, batch, module, loss_fn, optimizer, train_metrics_dict,
          train_rngs)
      train_metrics.append(flax.jax_utils.unreplicate(metrics))

      # Update the topk temperature with a linearly decreasing schedule if
      # enabled.
      if (config.get("linear_decrease_perturbed_sigma", False) and
          config.get("selection_method", "") == "perturbed-topk"):
        model_state = state.model_state.as_dict()
        if "/PatchNet_0" in model_state:
          net_str = "/PatchNet_0"
        else:
          net_str = "/"
        progress = step / config.num_train_steps
        sigma_multiplier = 1. - progress
        # Note: "sigma_mutiplier" (sic) is the key name used in the stored
        # model state, so it is kept as-is.
        previous_mult = model_state[net_str]["sigma_mutiplier"]
        sigma_multiplier = sigma_multiplier + jnp.zeros_like(previous_mult)
        model_state[net_str]["sigma_mutiplier"] = sigma_multiplier
        state = state.replace(model_state=nn.Collection(model_state))

      if debug_enabled:
        if utils.has_any_inf_or_nan(metrics):
          # Save checkpoints for post-mortem debugging.
          if previous_state:
            ckpt.save(flax.jax_utils.unreplicate(previous_state))
          ckpt.save(flax.jax_utils.unreplicate(state))
          # Log gradients and updates.
          if previous_grads or previous_updates:
            write_gradient_histogram(writer, step, grads=previous_grads,
                                     updates=previous_updates)
          write_gradient_histogram(writer, step + 1, grads=grads,
                                   updates=updates)
          raise RuntimeError("A training metric took an invalid value: "
                             f"{metrics}.")

      logging.log_first_n(logging.INFO, "Finished training step %d.", 3, step)
      report_progress(step)

      if step % config.log_loss_every_steps == 0 or is_last_step:
        results = train_metrics.result()
        writer.write_scalars(step, results)
        writer.write_scalars(step, step_timer.get_and_reset(step))
        if utils.has_any_inf_or_nan(results):
          raise ValueError("A training metric took an invalid value.")
        train_metrics.reset()

    if step % config.checkpoint_every_steps == 0 or is_last_step:
      with step_timer.paused():
        ckpt.save(flax.jax_utils.unreplicate(state))

    # Evaluation.
    if step % config.eval_every_steps == 0 or is_last_step:
      with step_timer.paused():
        eval_metrics, first_batch_stats, eval_rngs = evaluate(
            state, module, eval_ds, eval_metrics_dict, eval_rngs)

      if jax.host_id() == 0:
        log_histograms = config.get("log_histograms", False)
        log_images = config.get("log_images", True)
        # Log the last gradients and updates histograms.
        if not do_eval_only:
          write_stats_results(writer, step, training_stats,
                              stats_aggregators, prefix="train/",
                              log_images=log_images)
          if log_histograms:
            write_gradient_histogram(writer, step, grads=grads,
                                     updates=updates)

        write_stats_results(writer, step, first_batch_stats,
                            stats_aggregators, prefix="eval/",
                            log_images=log_images)

        # Write patch representation histograms.
        if (log_histograms and first_batch_stats and
            "patch_representations" in first_batch_stats):
          patch_representations = first_batch_stats["patch_representations"]
          writer.write_histograms(
              step, {"patch_representations": patch_representations})

        if eval_metrics:
          writer.write_scalars(step, eval_metrics)

  writer.flush()
  return state

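# NOTE: The debug path above relies on `utils.has_any_inf_or_nan`, which is not
# shown in this file. Below is a minimal sketch of such a helper, assuming the
# metrics arrive as a pytree of arrays. The name and behavior are an assumed
# stand-in, not the actual `utils` implementation:
import jax
import jax.numpy as jnp


def has_any_inf_or_nan(tree) -> bool:
  """Returns True if any leaf of the pytree contains an inf or a NaN."""
  return any(
      bool(jnp.any(~jnp.isfinite(leaf)))
      for leaf in jax.tree_util.tree_leaves(tree))
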
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  is_first_process = jax.process_index() == 0
  tf.io.gfile.makedirs(workdir)

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  train_ds, eval_ds, test_ds, encoder = input_pipeline.get_datasets(config)

  config.seq_length = 250
  vocab_size = int(encoder.vocab_size())
  config.num_classes = vocab_size
  config.data_shape = (config.seq_length, 1)

  logging.info('Training with vocab size %d', vocab_size)

  def decode_tokens(toks):
    return encoder.detokenize(toks)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  config.per_device_batch_size = config.batch_size // jax.process_count()

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  model, initial_variables = model_setup(init_rng, config)

  # Instead of passing the optimizer fns directly, we use a fn that returns
  # the optimizer given a learning rate.
  def tx_fn(lr):
    return optax.adamw(
        lr, b1=0.9, b2=0.99, eps=1e-08, eps_root=0.0,
        weight_decay=config.weight_decay)

  state = language_train_state.TrainState.create(
      params=initial_variables['params'], tx_fn=tx_fn)

  # We access model params only from state below via state.params.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated model state from the last checkpoint.
    state = checkpoints.restore_checkpoint(workdir, state)
    # Grab last step.
    start_step = int(state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=not is_first_process)
  if start_step == 0:
    config_dict = dict(config)
    writer.write_hparams(config_dict)

  if is_first_process and start_step == 0:
    # Dump config file to work dir for easy model loading.
    config_path = os.path.join(workdir, 'config')
    with tf.io.gfile.GFile(config_path, 'wb') as fp:
      pickle.dump(config, fp)

  logging.info('Using state %s', type(state))

  # Replicate state.
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_scheduler(
      factors=config.lr_factors,
      base_learning_rate=config.learning_rate,
      warmup_steps=config.warmup_steps)

  # Compile multidevice versions of train/eval/predict step fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model=model,
          learning_rate_fn=learning_rate_fn,
          clip_grad=config.clip_grad,
          ema_momentum=config.get('ema_momentum', 0.999)),
      axis_name='batch',
      in_axes=(0, 0),
      donate_argnums=(0,))
  p_eval_step = jax.pmap(
      functools.partial(eval_step, model=model), axis_name='batch')

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of train PRNG keys here, but update them afterwards
  # inside the main pmap'd training update for performance.
  rng = jax.random.fold_in(rng, jax.process_index())
  rng1, rng2, rng3, extensive_eval_rngs, sample_rng = jax.random.split(rng, 5)
  train_rngs = jax.random.split(rng1, jax.local_device_count())
  eval_rngs = jax.random.split(rng2, jax.local_device_count())
  test_rngs = jax.random.split(rng3, jax.local_device_count())
  del rng, rng1, rng2, rng3

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if is_first_process:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []

  # Iterator that does epoch-wise indefinite iteration.
  def iterate_train(train_ds):
    epoch = 1
    while True:
      logging.info('Starting epoch %d', epoch)
      for batch in train_ds:
        yield batch
      epoch += 1

  train_iter = iterate_train(train_ds)

  kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps)
  kl_history = []

  with metric_writers.ensure_flushes(writer):
    step = start_step
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        state, metrics = p_train_step(state, batch, rng=train_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step > 0 and (step % config.eval_every_steps == 0 or is_last_step):
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)

          # First handle loss terms per step.
          t_batch = train_metrics.pop('t_batch')
          nelbo_per_t_batch = train_metrics.pop('nelbo_per_t_batch')
          kl_tracker_train.update(t_batch.reshape(-1),
                                  nelbo_per_t_batch.reshape(-1))
          kl_values = kl_tracker_train.get_kl_per_t()
          kl_history.append(np.array(kl_values))
          kl_history = kl_history[-100:]  # Keep last 100 items only.

          # Handle remaining `standard` metrics.
          summary = jax.tree_map(jnp.mean, train_metrics)
          summary = {'train_' + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results, eval_rngs = evaluate(
              p_eval_step=p_eval_step,
              params=state.ema_params,
              eval_ds=eval_ds,
              rng=eval_rngs)
          writer.write_scalars(
              step, {'eval_' + k: v for k, v in eval_results.items()})

          test_results, test_rngs = evaluate(
              p_eval_step=p_eval_step,
              params=state.ema_params,
              eval_ds=test_ds,
              rng=test_rngs)
          writer.write_scalars(
              step, {'test_' + k: v for k, v in test_results.items()})

      if step == 1000 or (step > 0 and
                          step % config.detailed_eval_every_steps == 0):
        if is_first_process:
          loss_components_path = os.path.join(workdir, 'loss_components')
          with tf.io.gfile.GFile(loss_components_path, 'wb') as fp:
            pickle.dump(kl_history[-1], fp)

        extensive_eval_rngs = extensive_eval(
            config, extensive_eval_rngs, writer, workdir, model, state,
            kl_history, test_ds, step, decode_tokens)

        with report_progress.timed('generate_text'):
          generate_prediction(sample_rng, config, model, state, writer,
                              decode_tokens, step)

      # Save a checkpoint on one host after every checkpoint_every_steps.
      save_checkpoint = (
          step > 0 and
          (step % config.checkpoint_every_steps == 0 or is_last_step))
      if config.save_checkpoints and save_checkpoint and is_first_process:
        with report_progress.timed('checkpoint'):
          checkpoints.save_checkpoint(
              workdir, jax_utils.unreplicate(state), step, overwrite=True)

def train_and_evaluate(config, workdir, strategy):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
    strategy: Distribution strategy to use for distributing the model.
  """
  tf.io.gfile.makedirs(workdir)

  tf_rng, data_rng = tf.random.experimental.stateless_split((config.seed, 0),
                                                            2)
  tf.random.set_seed(tf_rng.numpy()[0])

  # Input pipeline.
  ds_info, train_ds, val_ds, test_ds = input_pipeline.create_datasets(
      config, data_rng, strategy=strategy)
  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types

  # Learning rate schedule.
  num_train_steps = config.num_train_steps
  if num_train_steps == -1:
    num_train_steps = (ds_info.splits["train"].num_examples //
                       config.global_batch_size * config.num_epochs)
  steps_per_epoch = num_train_steps // config.num_epochs
  logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
               steps_per_epoch)

  # We treat the learning rate in the config as the learning rate for batch
  # size 256 but scale it according to our batch size.
  base_learning_rate = config.learning_rate * config.global_batch_size / 256.0

  # Initialize model.
  num_classes = ds_info.features["label"].num_classes
  if config.distill_teacher:
    do_distill = True
    teacher_file_list = config.distill_teacher.split(",")
    teacher_models = load_teacher_models(teacher_file_list, num_classes,
                                         config, strategy)
    distill_params = {
        "alpha": config.distill_alpha,
        "beta": config.distill_fd_beta,
        "teacher_model": TeacherModel(teacher_models, name="teacher"),
    }
  else:
    do_distill = False
    distill_params = None

  state = create_state(config, num_classes=num_classes, strategy=strategy)
  ckpt_manager = tf.train.CheckpointManager(
      checkpoint=state, directory=workdir, max_to_keep=5)
  if ckpt_manager.latest_checkpoint:
    state.restore(ckpt_manager.latest_checkpoint)
    logging.info("Restored from %s", ckpt_manager.latest_checkpoint)
  else:
    logging.info("Initializing from scratch.")
  initial_step = state.global_step.numpy().item()

  learning_rate_fn = functools.partial(
      get_learning_rate,
      base_learning_rate=base_learning_rate,
      steps_per_epoch=steps_per_epoch,
      num_epochs=config.num_epochs,
      warmup_epochs=config.warmup_epochs)

  writer = metric_writers.create_default_writer(workdir)
  writer.write_hparams(dict(config))

  logging.info("Starting training loop at step %d.", initial_step)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer)
  with metric_writers.ensure_flushes(writer):
    for step in range(initial_step, num_train_steps + 1):
      state.model.trainable = True

      # `step` is a Python integer. `global_step` is a TF variable on the
      # GPU/TPU devices.
      is_last_step = step == num_train_steps

      train_step(state, train_iter, config.weight_decay, learning_rate_fn,
                 do_distill, distill_params, strategy)
      state.train_metrics.update_state_lr(
          learning_rate_fn(state.global_step.numpy().item()))

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      report_progress(step)

      if step == initial_step:
        parameter_overview.log_parameter_overview(state.model)

      if step % config.log_loss_every_steps == 0 or is_last_step:
        writer.write_scalars(step, state.train_metrics.result())
        state.train_metrics.reset_states()
        state.train_metrics.reset_lr()

      if step % config.eval_every_steps == 0 or is_last_step:
        state.model.trainable = False
        if config.dataset == "imagenet-lt":
          evaluate(state, val_ds, state.val_metrics, strategy)
          writer.write_scalars(step, state.val_metrics.result())
          logging.info("Num val images %d",
                       state.val_metrics.accuracy.count.numpy())
        evaluate(state, test_ds, state.test_metrics, strategy)
        writer.write_scalars(step, state.test_metrics.result())
        logging.info("Num test images %d",
                     state.test_metrics.accuracy.count.numpy())

      if step % config.checkpoint_every_steps == 0 or is_last_step:
        checkpoint_path = ckpt_manager.save(step)
        logging.info("Saved checkpoint %s", checkpoint_path)

  logging.info("Finishing training at step %d", step)
  logging.info("Saving the final weights")
  file_path = os.path.join(workdir, "final_weights")
  state.model.save_weights(file_path, save_format="tf")

def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  tf.io.gfile.makedirs(workdir)
  rng = jax.random.PRNGKey(config.seed)

  # Build input pipeline.
  rng, data_rng = jax.random.split(rng)
  data_rng = jax.random.fold_in(data_rng, jax.host_id())
  splits = input_pipeline.create_datasets(config, data_rng)
  num_classes = splits.info.features["label"].num_classes
  train_iter = iter(splits.train)  # pytype: disable=wrong-arg-types

  # Learning rate schedule.
  num_train_steps = config.num_train_steps
  if num_train_steps == -1:
    num_train_steps = splits.train.cardinality().numpy()
  steps_per_epoch = num_train_steps // config.num_epochs
  logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
               steps_per_epoch)

  # We treat the learning rate in the config as the learning rate for batch
  # size 32 but scale it according to our batch size.
  global_batch_size = config.per_device_batch_size * jax.device_count()
  base_learning_rate = config.learning_rate * global_batch_size / 32.0
  learning_rate_fn = functools.partial(
      get_learning_rate,
      base_learning_rate=base_learning_rate,
      steps_per_epoch=steps_per_epoch,
      num_epochs=config.num_epochs,
      warmup_epochs=config.warmup_epochs)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model, state = create_train_state(
      config, model_rng,
      input_shape=splits.train.element_spec["input"].shape[1:],
      num_classes=num_classes)

  # Set up checkpointing of the model and the input pipeline.
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = checkpoint.MultihostCheckpoint(
      checkpoint_dir, {"train_iter": train_iter}, max_to_keep=2)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step) + 1

  # Count the number of trainable parameters. This must be done before
  # replicating the state to avoid double-counting replicated parameters.
  param_count = sum(p.size for p in jax.tree_leaves(state.optimizer.target))

  # Distribute training over local devices.
  state = flax_utils.replicate(state)

  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model=model,
          learning_rate_fn=learning_rate_fn,
          weight_decay=config.weight_decay),
      axis_name=_PMAP_AXIS_NAME)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  if initial_step == 1:
    writer.write_hparams(dict(config))
    # Log the number of trainable params.
    writer.write_scalars(initial_step, {"param_count": param_count})

  logging.info("Starting training loop at step %d.", initial_step)
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
    ]
  train_metrics = None
  with metric_writers.ensure_flushes(writer):
    for step in range(initial_step, num_train_steps + 1):
      # `step` is a Python integer. `state.step` is a JAX integer on the
      # GPU/TPU devices.
      is_last_step = step == num_train_steps
      with jax.profiler.StepTraceContext("train", step_num=step):
        batch = jax.tree_map(np.asarray, next(train_iter))
        state, metrics_update = p_train_step(state=state, batch=batch)
        metric_update = flax_utils.unreplicate(metrics_update)
        train_metrics = (metric_update if train_metrics is None
                         else train_metrics.merge(metric_update))

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      if step % config.log_loss_every_steps == 0 or is_last_step:
        writer.write_scalars(step, train_metrics.compute())
        train_metrics = None

      # When combining train and eval, we do not evaluate while training.
      if ((step % config.eval_every_steps == 0 or is_last_step) and
          not config.combine_train_val_and_eval_on_test):
        with report_progress.timed("eval"):
          eval_metrics = evaluate(model, state, splits.validation,
                                  config.num_eval_steps)
        writer.write_scalars(step, eval_metrics.compute())

      if step % config.checkpoint_every_steps == 0 or is_last_step:
        with report_progress.timed("checkpoint"):
          ckpt.save(flax_utils.unreplicate(state))

      if is_last_step and config.combine_train_val_and_eval_on_test:
        # Evaluate a single time on the test set when requested.
        with report_progress.timed("test"):
          test_metrics = evaluate(model, state, splits.test,
                                  config.num_eval_steps)
        writer.write_scalars(step, test_metrics.compute())

  logging.info("Finishing training at step %d", num_train_steps)

def train(config: ml_collections.ConfigDict):
  """Run training."""
  # Establish host information.
  local_device_count = jax.local_device_count()
  host_count = jax.process_count()
  host_id = jax.process_index()

  task = task_registry.get_registered_task(config.task_name)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)

  model_config = ml_collections.FrozenConfigDict(config.model_config)
  model = task.build_model(model_config)

  # Initialization needs to be pmapped because models use collective ops.
  # Create dummy input.
  dummy_input = {
      key: jnp.tile(value, (local_device_count,) + (1,) * value.ndim)
      for key, value in task.dummy_input(config).items()
  }

  rng, init_rng = jax.random.split(rng)
  init_rng = jax.random.split(init_rng, local_device_count)

  logging.info('Initializing model.')
  initial_variables = jax.pmap(
      model.init, 'batch', static_broadcasted_argnums=2)(init_rng,
                                                         dummy_input, True)
  logging.info('Finished initializing model.')
  initial_variables = initial_variables.unfreeze()

  if config.load_weights is not None:
    logging.info('Loading model weights from file')
    loaded_variables = task.load_weights(config)
    unexpected, missing = checkpoint_utils.merge_nested_dicts(
        initial_variables, loaded_variables)
    logging.info('*** Unexpected features: ***')
    for feature_name in unexpected:
      logging.info('\t%s', feature_name)
    logging.info('*** Missing features: ***')
    for feature_name in missing:
      logging.info('\t%s', feature_name)

  model_vars = {
      key: value for key, value in initial_variables.items()
      if key != 'params'
  }

  learning_rate_fn = optim_utils.create_learning_rate_scheduler(
      learning_rate=config.learning_rate,
      warmup=config.warmup,
      warmup_steps=config.get('warmup_steps', None),
      linear_decay=config.linear_decay,
      max_steps=config.num_train_steps,
      decay_minimum_factor=config.get('decay_minimum_factor', None),
  )

  if config.weight_decay_exclude is not None:
    decay_mask = optim_utils.create_dict_mask(initial_variables['params'],
                                              config.weight_decay_exclude)
  else:
    decay_mask = None
  tx = optax.adamw(
      learning_rate=learning_rate_fn,
      weight_decay=config.weight_decay,
      b1=0.9,
      b2=0.999,
      eps=1e-6,
      mask=decay_mask)
  if config.grad_clip is not None:
    tx = optax.chain(tx, optax.clip_by_global_norm(config.grad_clip))

  ignore_k_nans = config.get('ignore_k_nans')
  if ignore_k_nans is not None:
    tx = optax.apply_if_finite(tx, max_consecutive_errors=ignore_k_nans)

  loss_fn = task.make_loss_fn(config)
  train_state = ts.TrainState.create(
      apply_fn=loss_fn,
      params=jax_utils.unreplicate(initial_variables['params']),
      tx=tx,
  )

  # We access model params only from the train state.
  del initial_variables

  # Restore unreplicated train state from the last checkpoint.
  train_state = checkpoints.restore_checkpoint(config.model_dir, train_state)
  # Grab last step.
  start_step = int(train_state.step)

  writer = metric_writers.create_default_writer(
      config.model_dir, just_logging=jax.process_index() > 0)
  if start_step == 0:
    writer.write_hparams(config.to_dict())

  dropout_rngs = jax.random.split(rng, local_device_count)
  del rng

  # Load datasets.
  logging.info('Loading dataset.')

  # Make sure we don't re-use the same data if we load weights or a checkpoint.
  seed = config.seed + start_step
  if config.load_weights:
    seed = seed + hash(config.load_weights)

  name_to_features = task.get_name_to_features(config)
  preprocess_fn = task.make_preprocess_fn(config)
  collater_fn = task.make_collater_fn(config)

  train_data = data_utils.load_multi_dataset(
      datasets_config=config.train_data,
      name_to_features=name_to_features,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=True,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      seed=seed,  # Use the shifted seed computed above, not config.seed.
  )
  train_iter = iter(train_data)

  pad_eval = config.get('pad_eval', False)
  if pad_eval:
    logging.info('Eval data is padded such that no samples are dropped.')
  else:
    logging.warning('Eval data is NOT padded -- some samples might be '
                    'dropped.')
  eval_data = data_utils.load_multi_dataset(
      datasets_config=config.eval_data,
      name_to_features=name_to_features,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=False,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      seed=config.seed,
      pad_eval=pad_eval,
  )
  eval_data = list(eval_data)
  logging.info('Loaded %d samples for evaluation.', len(eval_data))

  # Setup postprocessing_fn for saving samples occasionally.
  if config.get('save_samples_every_steps') is not None:
    if config.get('save_samples_every_steps') % config.eval_every_steps != 0:
      raise ValueError(
          '`eval_every_steps` must divide `save_samples_every_steps`.')
    postprocessing_fn = task.make_output_postprocess_fn(config)

  # Training loop.
  logging.info('Starting training.')

  # Replicate train state.
  train_state = jax_utils.replicate(train_state)

  # Compile multidevice versions of the train/eval/predict step.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model_config=model_config,
      ),
      axis_name='batch',
      donate_argnums=(0,),
  )  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          model_config=model_config,
      ),
      axis_name='batch')

  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=config.model_dir,
                                 num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and perform a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        batch = jax.tree_map(jnp.asarray, train_iter.get_next())
        train_state, metrics = p_train_step(
            train_state, model_vars, batch, dropout_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          summary = metric_utils.process_metrics(metrics_sums,
                                                 prefix='train')
          summary['learning_rate'] = learning_rate_fn(step)
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results, eval_auxiliary = evaluate(
              eval_step_fn=p_eval_step,
              train_state=train_state,
              model_vars=model_vars,
              eval_data=eval_data,
          )
          writer.write_scalars(step, eval_results)

        if config.get('save_samples_every_steps') is not None:
          with report_progress.timed('save_samples'):
            if config.get('save_first_batch_only', True):
              # Keep only the first batch of auxiliary outputs.
              eval_auxiliary = eval_auxiliary[:1]
            eval_processed = [
                postprocessing_fn(batch, auxiliary_output)
                for batch, auxiliary_output in eval_auxiliary
            ]
            data_utils.save_samples_to_json(eval_processed, config, step)

      # Save a checkpoint on one host after every checkpoint_every_steps.
      save_checkpoint = (
          step % config.checkpoint_every_steps == 0 or is_last_step)
      if (config.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving checkpoint at step %s', step)
          checkpoints.save_checkpoint(
              config.model_dir,
              jax_utils.unreplicate(train_state),
              step,
              keep=config.get('keep_checkpoints', 1),
              keep_every_n_steps=config.get('keep_checkpoint_every_steps'),
          )

      save_model = (
          config.save_every_steps and
          (step % config.save_every_steps == 0 or is_last_step) and
          step != 0)
      if save_model and jax.process_index() == 0:
        with report_progress.timed('checkpoint'):
          logging.info('Saving weights at step %s', step)
          save_path = os.path.join(config.model_dir, 'weights',
                                   'step' + str(step))
          # By default, save only encoder weights.
          weights = jax_utils.unreplicate(train_state).params['encoder']
          checkpoint_utils.save_weights(save_path, weights)

def train(base_dir, config):
  """Train function."""
  print(config)
  chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'train'))
  writer = create_default_writer()

  # Initialize dataset.
  key = jax.random.PRNGKey(config.seed)
  key, subkey = jax.random.split(key)
  ds = dataset.get_dataset(config, subkey, num_tasks=config.num_tasks)
  ds_iter = iter(ds)

  key, subkey = jax.random.split(key)
  encoder = MLPEncoder(**config.encoder)

  train_config = config.train.to_dict()
  train_method = train_config.pop('method')

  module_config = train_config.pop('module')
  module_class = module_config.pop('name')
  module = globals().get(module_class)(encoder, **module_config)

  train_step = globals().get(f'train_step_{train_method}')
  train_step = functools.partial(train_step, **train_config)

  params = module.init(subkey, next(ds_iter)[0])
  lr = optax.cosine_decay_schedule(config.learning_rate,
                                   config.num_train_steps)
  optim = optax.chain(
      optax.adam(lr),
      # optax.adaptive_grad_clip(0.15)
  )

  state = TrainState.create(apply_fn=module.apply, params=params, tx=optim)
  state = chkpt_manager.restore_or_initialize(state)

  # Hooks.
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  hooks = [
      report_progress,
      periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
  ]

  def handle_preemption(signal_number, _):
    logging.info('Received signal %d, saving checkpoint.', signal_number)
    with report_progress.timed('checkpointing'):
      chkpt_manager.save(state)
    logging.info('Finished saving checkpoint.')

  signal.signal(signal.SIGTERM, handle_preemption)

  metrics = TrainMetrics.empty()
  with metric_writers.ensure_flushes(writer):
    for step in tqdm.tqdm(range(state.step, config.num_train_steps)):
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        states, targets = next(ds_iter)
        state, metrics = train_step(state, metrics, states, targets)

      logging.log_first_n(logging.INFO, 'Finished training step %d', 5, step)

      if step % config.log_metrics_every == 0:
        writer.write_scalars(step, metrics.compute())
        metrics = TrainMetrics.empty()

      # if step % config.log_eval_metrics_every == 0 and isinstance(
      #     ds, dataset.MDPDataset):
      #   eval_metrics = evaluate_mdp(state, ds.aux_task_matrix, config)
      #   writer.write_scalars(step, eval_metrics.compute())

      for hook in hooks:
        hook(step)

  chkpt_manager.save(state)
  return state

def train_and_evaluate(config, work_dir, try_checkpoint=True):
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    work_dir: Directory where the tensorboard summaries are written to.
    try_checkpoint: Whether to try to load a checkpoint (usually enabled;
      practical to disable for debugging purposes).

  Returns:
    The train state (which includes the `.params`).
  """
  # Init rng key.
  rng = jax.random.PRNGKey(config.seed)
  data_rng, rng = jax.random.split(rng)
  is_first_host = jax.process_index() == 0

  if config.dataset.name.endswith('speech_commands09'):
    ds, ds_metadata = input_pipeline_sc09.get_dataset(data_rng, config)
  else:
    raise ValueError(f'Unknown dataset {config.dataset.name}.')

  # Immediately create infinite iterators.
  it = jax.tree_map(util_fns.get_iterator, ds)

  # TODO(agritsenko): Can we fix the ugly nested dicts?
  config.data_shape = ds_metadata['train']['shape']['inputs'][2:]
  config.num_classes = ds_metadata['train']['num_classes']
  config.sample_rate = ds_metadata['train']['sample_rate']

  writer = metric_writers.create_default_writer(
      work_dir, just_logging=jax.process_index() > 0)

  rng, init_rng = jax.random.split(rng)
  model, variables = model_setup(init_rng, config)

  # From now on we want different rng across hosts:
  rng = jax.random.fold_in(rng, jax.process_index())

  def tx_fn(lr):
    return optax.adamw(
        lr, b1=0.9, b2=config.beta2, eps=1e-08, eps_root=0.0,
        weight_decay=config.weight_decay)

  state = language_train_state.TrainState.create(
      params=variables['params'], tx_fn=tx_fn)

  start_step = None
  if try_checkpoint:
    state, start_step = checkpoint.restore_from_path(work_dir, state)
  start_step = start_step or 0

  # Use different rngs for train & eval.
  rng_train, rng_eval, rng_sample = jax.random.split(rng, 3)

  kl_tracker = util_fns.KLTracker(num_steps=model.num_steps)
  kl_history = []

  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      **config.learning_rate)
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=config,
          learning_rate_fn=learning_rate_fn,
          model=model),
      axis_name='batch',
      in_axes=(None, 0, 0),
      out_axes=(0, 0, None),
      donate_argnums=(2,))

  # The only axes that are broadcast are the in- and output rng key ones. The
  # rng is the first arg, and the last return value.
  p_eval_step = jax.pmap(
      functools.partial(eval_step, model=model),
      axis_name='batch',
      in_axes=(None, 0, 0),
      out_axes=(0, 0, None))

  # Training length.
  logging.info('Training will start from step %d', start_step)

  # Replicate state.
  state = flax.jax_utils.replicate(state)

  # Setup hooks.
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if is_first_host:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=work_dir, num_profile_steps=5)
    ]

  with metric_writers.ensure_flushes(writer):
    batch_metrics = []
    for step in range(start_step, config.num_train_steps):
      logging.log_first_n(logging.INFO, 'Train step: %d', 5, step)
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        state, metrics, rng_train = p_train_step(
            rng_train, next(it['train']), state)
      batch_metrics.append(metrics)

      # Cycle through hooks.
      for h in hooks:
        h(step)

      is_last_step = step == config.num_train_steps - 1
      if (step % config.log_every_steps == 0) or is_last_step:
        with report_progress.timed('training_metrics'):
          ################### Process batch metrics ##########################
          batch_metrics = jax.device_get(
              flax.jax_utils.unreplicate(batch_metrics))

          if 't_batch' in metrics:
            # TODO(agritsenko): Factor out into a separate function.
            # This processes the loss per t. Although it uses two nested
            # for-loops (counting the one inside kl_tracker), it does not
            # meaningfully hurt timing performance.
            batch_t = [
                metrics['t_batch'].reshape(-1) for metrics in batch_metrics
            ]
            batch_nelbo_per_t = [
                metrics['nelbo_per_t_batch'].reshape(-1)
                for metrics in batch_metrics
            ]
            for t, nelbo_per_t in zip(batch_t, batch_nelbo_per_t):
              kl_tracker.update(t, nelbo_per_t)

          ################### Process batch metrics ##########################
          metrics = {
              key: np.mean([metrics[key] for metrics in batch_metrics])
              for key in batch_metrics[0] if 'batch' not in key
          }

          # Metric logging.
          if is_first_host:
            log_standard_metrics(writer, step, train_metrics=metrics)
          batch_metrics = []

      if config.eval_every_steps and (
          (step % config.eval_every_steps == 0) or is_last_step):
        with report_progress.timed('eval'):
          ####################### Run evaluation #############################
          metrics, rng_eval = eval_model(
              p_eval_step, rng_eval, state, it['eval'],
              (ds_metadata['eval']['num_batches'] *
               config.get('num_eval_passes', 1)))

          # Metric logging.
          if is_first_host:
            log_standard_metrics(writer, step, eval_metrics=metrics)

        # Track KL (unrelated to the eval, but nice to not do every step).
        kl_values = kl_tracker.get_kl_per_t()
        kl_history.append(np.array(kl_values))
        kl_history = kl_history[-50:]

      if config.sample_every_steps and (
          (step % config.sample_every_steps == 0) or is_last_step):
        with report_progress.timed('sample'):
          ######################### Run sampling #############################
          chain = model.sample(
              jax.random.fold_in(rng_sample, step),
              state.ema_params,
              config.sample_batch_size,
              chain_out_size=config.get('chain_out_size', model.num_stages))

          if is_first_host:
            chain = jax.device_get(chain)
            long_sample = np.reshape(chain[-1],
                                     (1, -1, 1)).astype(np.float32)
            long_sample = (2. * long_sample) / config.num_classes - 1.
            writer.write_audios(
                step, {'samples': long_sample},
                sample_rate=config.sample_rate)

      ######################### Checkpointing ###############################
      if is_first_host and config.checkpoint_every_steps and (
          (step % config.checkpoint_every_steps == 0) or is_last_step):
        logging.info('Saving checkpoint: step %d', step)
        with report_progress.timed('checkpoint'):
          checkpoint.save_checkpoint(
              work_dir,
              state=flax.jax_utils.unreplicate(state),
              step=step)
        logging.info('Finished saving checkpoint: step %d', step)

  return state

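# NOTE: `util_fns.KLTracker` is used above but not defined in this file. Below
# is a minimal sketch consistent with the calls `update(t, nelbo_per_t)` and
# `get_kl_per_t()`: a running mean of the NELBO bucketed by diffusion
# timestep. This is a hypothetical stand-in, not the actual util_fns
# implementation:
import numpy as np


class KLTracker:
  """Tracks a running mean of per-timestep NELBO values."""

  def __init__(self, num_steps: int):
    self._sums = np.zeros(num_steps)
    self._counts = np.zeros(num_steps)

  def update(self, t_batch, nelbo_per_t_batch):
    # Accumulate values bucketed by their (integer) timestep.
    t = np.asarray(t_batch, dtype=np.int64)
    np.add.at(self._sums, t, np.asarray(nelbo_per_t_batch))
    np.add.at(self._counts, t, 1)

  def get_kl_per_t(self):
    # Avoid division by zero for timesteps that were never observed.
    return self._sums / np.maximum(self._counts, 1)
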
def main(_):
  if FLAGS.config.precrop_iters > 0 and FLAGS.config.batching:
    raise ValueError(
        "'precrop_iters' has no effect when 'batching' the dataset")
  # render_factor == 0 is valid and means rendering at full resolution.
  assert FLAGS.config.down_factor > 0 and FLAGS.config.render_factor >= 0

  logging.info("JAX host: %d / %d", jax.process_index(), jax.host_count())
  logging.info("JAX local devices: %r", jax.local_devices())

  platform.work_unit().set_task_status(
      f"host_id: {jax.process_index()}, host_count: {jax.host_count()}")
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.model_dir, "model_dir")

  os.makedirs(FLAGS.model_dir, exist_ok=True)
  rng = jax.random.PRNGKey(FLAGS.seed)
  rng, rng_coarse, rng_fine, data_rng, step_rng = jax.random.split(rng, 5)
  rngs = common_utils.shard_prng_key(step_rng)

  ### Load dataset and data values.
  datasets, counts, optics, render_datasets = get_dataset(
      FLAGS.data_dir, FLAGS.config, rng=data_rng,
      num_poses=FLAGS.config.num_poses)
  train_ds, val_ds, test_ds = datasets
  *_, test_items = counts
  hwf, r_hwf, near, far = optics
  render_ds, render_vdirs_ds, num_poses = render_datasets
  iter_render_ds = zip(range(num_poses), render_ds)
  iter_vdirs_ds = zip(range(num_poses), render_vdirs_ds)
  iter_test_ds = zip(range(test_items), test_ds)
  img_h, img_w, _ = hwf

  logging.info("Num poses: %d", num_poses)
  logging.info("Splits: train - %d, val - %d, test - %d", *counts)
  logging.info("Images: height %d, width %d, focal %.5f", *hwf)
  logging.info("Render: height %d, width %d, focal %.5f", *r_hwf)

  ### Init model parameters and optimizer.
  initialized_ = functools.partial(initialized,
                                   model_config=FLAGS.config.model)
  pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
  views_shape = (FLAGS.config.num_rand, 3)
  model_coarse, params_coarse = initialized_(rng_coarse, pts_shape,
                                             views_shape)

  schedule_fn = optax.exponential_decay(
      init_value=FLAGS.config.learning_rate,
      transition_steps=FLAGS.config.lr_decay * 1000,
      decay_rate=FLAGS.config.decay_factor,
  )
  tx = optax.adam(learning_rate=schedule_fn)
  state = train_state.TrainState.create(
      apply_fn=(model_coarse.apply, None),
      params={"coarse": params_coarse},
      tx=tx)

  if FLAGS.config.num_importance > 0:
    pts_shape = (
        FLAGS.config.num_rand,
        FLAGS.config.num_importance + FLAGS.config.num_samples,
        3,
    )
    model_fine, params_fine = initialized_(rng_fine, pts_shape, views_shape)
    state = train_state.TrainState.create(
        apply_fn=(model_coarse.apply, model_fine.apply),
        params={
            "coarse": params_coarse,
            "fine": params_fine
        },
        tx=tx,
    )

  state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
  start_step = int(state.step)

  # Cycle already-seen examples if resuming from a checkpoint
  # (only useful for ensuring a deterministic dataset, slow for large
  # start_step).
  # if start_step != 0:
  #   for _ in range(start_step):
  #     _ = next(train_ds)

  # parameter_overview.log_parameter_overview(state.optimizer_coarse.target)
  # if FLAGS.config.num_importance > 0:
  #   parameter_overview.log_parameter_overview(state.optimizer_fine.target)

  state = jax.device_put_replicated(state, jax.local_devices())

  ### Build "pmapped" functions for distributed training.
  train_fn = functools.partial(train_step, near, far, FLAGS.config,
                               schedule_fn)
  p_train_step = jax.pmap(
      train_fn,
      axis_name="batch",
      in_axes=(0, 0, None, 0),
      # donate_argnums=(0, 1, 2),
  )

  def render_fn(state, rays):
    step_fn = functools.partial(eval_step, FLAGS.config, near, far, state)
    return lax.map(step_fn, rays)

  p_eval_step = jax.pmap(
      render_fn,
      axis_name="batch",
      # in_axes=(0, 0, None),
      # donate_argnums=(0, 1),
  )

  # TODO: add hparams.
  writer = metric_writers.create_default_writer(
      FLAGS.model_dir, just_logging=jax.process_index() > 0)
  logging.info("Starting training loop.")

  hooks = []
  profiler = periodic_actions.Profile(num_profile_steps=5,
                                      logdir=FLAGS.model_dir)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=FLAGS.config.num_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [profiler, report_progress]
  train_metrics = []
  gen_video_ = functools.partial(gen_video, FLAGS.model_dir)

  for step in range(start_step, FLAGS.config.num_steps + 1):
    is_last_step = step == FLAGS.config.num_steps

    batch = next(train_ds)
    coords = None
    if not FLAGS.config.batching:
      coords = jnp.meshgrid(
          jnp.arange(img_h), jnp.arange(img_w), indexing="ij")
      if step < FLAGS.config.precrop_iters:
        dH = int(img_h // 2 * FLAGS.config.precrop_frac)
        dW = int(img_w // 2 * FLAGS.config.precrop_frac)
        coords = jnp.meshgrid(
            jnp.arange(img_h // 2 - dH, img_h // 2 + dH),
            jnp.arange(img_w // 2 - dW, img_w // 2 + dW),
            indexing="ij",
        )
      coords = jnp.stack(coords, axis=-1).reshape([-1, 2])

    with jax.profiler.StepTraceAnnotation("train", step_num=step):
      state, metrics = p_train_step(batch, state, coords, rngs)
    train_metrics.append(metrics)

    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
    _ = [h(step) for h in hooks]

    ### Write train summaries to TB.
    if step % FLAGS.config.i_print == 0 or is_last_step:
      with report_progress.timed("training_metrics"):
        train_metrics = common_utils.get_metrics(train_metrics)
        train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
        summary = {f"train/{k}": v for k, v in train_summary.items()}
        writer.write_scalars(step, summary)
      train_metrics = []

    ### Eval a random validation image and plot it to TB.
    if step % FLAGS.config.i_img == 0 and step > 0 or is_last_step:
      with report_progress.timed("validation"):
        inputs = next(val_ds)
        rays, padding = prepare_render_data(inputs["rays"]._numpy())
        outputs = p_eval_step(state, rays)
        preds, preds_c, z_std = jax.tree_map(
            lambda x: to_np(x, hwf, padding), outputs)
        loss = np.mean((preds["rgb"] - inputs["image"])**2)
        summary = {"val/loss": loss, "val/psnr": psnr_fn(loss)}
        writer.write_scalars(step, summary)

        summary = {
            "val/rgb": to_rgb(preds["rgb"]),
            "val/target": to_np(inputs["image"], hwf, padding),
            "val/disp": disp_post(preds["disp"], FLAGS.config),
            "val/acc": preds["acc"],
        }
        if FLAGS.config.num_importance > 0:
          summary["val/rgb_c"] = to_rgb(preds_c["rgb"])
          summary["val/disp_c"] = disp_post(preds_c["disp"], FLAGS.config)
          summary["val/z_std"] = z_std
        writer.write_images(step, summary)

    ### Render a video with test poses.
    if step % FLAGS.config.i_video == 0 and step > 0:
      with report_progress.timed("video_render"):
        logging.info("Rendering video at step %d", step)
        rgb_list = []
        disp_list = []
        for idx, inputs in tqdm(iter_render_ds, desc="Rays render"):
          rays, padding = prepare_render_data(inputs["rays"].numpy())
          preds, *_ = p_eval_step(state, rays)
          preds = jax.tree_map(lambda x: to_np(x, r_hwf, padding), preds)
          rgb_list.append(preds["rgb"])
          disp_list.append(preds["disp"])
        gen_video_(np.stack(rgb_list), "rgb", r_hwf, step)
        disp = np.stack(disp_list)
        gen_video_(disp_post(disp, FLAGS.config), "disp", r_hwf, step, ch=1)

        if FLAGS.config.use_viewdirs:
          rgb_list = []
          for idx, inputs in tqdm(iter_vdirs_ds, desc="Viewdirs render"):
            rays, padding = prepare_render_data(inputs["rays"].numpy())
            preds, *_ = p_eval_step(state, rays)
            rgb_list.append(to_np(preds["rgb"], r_hwf, padding))
          gen_video_(np.stack(rgb_list), "rgb_still", r_hwf, step)

    ### Save images in the test set.
    if step % FLAGS.config.i_testset == 0 and step > 0:
      with report_progress.timed("test_render"):
        logging.info("Rendering test set at step %d", step)
        test_losses = []
        for idx, inputs in tqdm(iter_test_ds, desc="Test render"):
          rays, padding = prepare_render_data(inputs["rays"].numpy())
          preds, *_ = p_eval_step(state, rays)
          save_test_imgs(FLAGS.model_dir, preds["rgb"], r_hwf, step, idx)

          if FLAGS.config.render_factor == 0:
            loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
            test_losses.append(loss)
        if FLAGS.config.render_factor == 0:
          loss = np.mean(test_losses)
          summary = {"test/loss": loss, "test/psnr": psnr_fn(loss)}
          writer.write_scalars(step, summary)
      writer.flush()

    ### Save ckpt.
    if step % FLAGS.config.i_weights == 0 or is_last_step:
      with report_progress.timed("checkpoint"):
        save_checkpoint(state, FLAGS.model_dir)
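
# NOTE: All of the loops above share the same clu skeleton: create a metric
# writer, register ReportProgress/Profile hooks on the first host, and invoke
# every hook once per step inside `metric_writers.ensure_flushes`. Below is a
# minimal self-contained sketch of that pattern; `run_one_step` is a
# hypothetical placeholder for any of the train steps above:
from clu import metric_writers, periodic_actions


def minimal_hook_loop(num_train_steps, workdir, run_one_step):
  writer = metric_writers.create_default_writer(workdir)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer)
  hooks = [
      report_progress,
      periodic_actions.Profile(logdir=workdir, num_profile_steps=5),
  ]
  with metric_writers.ensure_flushes(writer):
    for step in range(1, num_train_steps + 1):
      run_one_step(step)
      for h in hooks:
        h(step)  # Each hook decides internally whether to fire at this step.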