def _get_input_iterator(
    self, input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
    strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
  """Returns a distributed dataset iterator.

  Args:
    input_fn: (params: dict) -> tf.data.Dataset.
    strategy: an instance of tf.distribute.Strategy.

  Returns:
    An iterator that yields input tensors.
  """
  if input_fn is None:
    return None
  # When training with multiple TPU workers, datasets need to be cloned
  # across workers. Since a Dataset instance cannot be cloned in eager mode,
  # we instead pass a callable that returns a dataset.
  input_data = input_fn(self._params)
  return iter(strategy.experimental_distribute_dataset(input_data))
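
# --- Illustrative sketch (not part of this module) ---------------------------
# A minimal, standalone example of how an iterator produced by
# `_get_input_iterator` is typically consumed: `next()` yields per-replica
# batches that are passed to `strategy.run`. The `toy_input_fn`, `toy_params`,
# and `toy_step` names are hypothetical and exist only for illustration; `tf`
# refers to the module-level TensorFlow import.
def _distributed_iterator_demo():
  strategy = tf.distribute.MirroredStrategy()

  def toy_input_fn(params):
    # Stands in for `input_fn(self._params)` above: builds a batched dataset.
    return tf.data.Dataset.from_tensor_slices(
        tf.random.uniform([32, 4])).batch(params['batch_size'])

  toy_params = {'batch_size': 8}
  iterator = iter(
      strategy.experimental_distribute_dataset(toy_input_fn(toy_params)))

  @tf.function
  def toy_step(batch):
    # Runs on every replica and reduces the per-replica means to one scalar.
    per_replica = strategy.run(tf.reduce_mean, args=(batch,))
    return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)

  return toy_step(next(iterator))
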
def run_train_loop(
    train_dataset_builder: ub.datasets.BaseDataset,
    validation_dataset_builder: Optional[ub.datasets.BaseDataset],
    test_dataset_builder: ub.datasets.BaseDataset,
    batch_size: int,
    eval_batch_size: int,
    model: tf.keras.Model,
    optimizer: tf.keras.optimizers.Optimizer,
    eval_frequency: int,
    log_frequency: int,
    trial_dir: Optional[str],
    train_steps: int,
    mode: str,
    strategy: tf.distribute.Strategy,
    metrics: Dict[str, Union[tf.keras.metrics.Metric, rm.metrics.KerasMetric]],
    hparams: Dict[str, Any]):
  """Train, possibly evaluate the model, and record metrics."""
  checkpoint_manager = None
  last_checkpoint_step = 0
  if trial_dir:
    # TODO(znado): add train_iterator to this once DistributedIterators are
    # checkpointable.
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, trial_dir, max_to_keep=None)
    checkpoint_path = tf.train.latest_checkpoint(trial_dir)
    if checkpoint_path:
      last_checkpoint_step = int(checkpoint_path.split('-')[-1])
      if last_checkpoint_step >= train_steps:
        # If we have already finished training, exit.
        logging.info('Training has already finished at step %d. Exiting.',
                     train_steps)
        return
      elif last_checkpoint_step > 0:
        # Restore from where we previously finished.
        checkpoint.restore(checkpoint_manager.latest_checkpoint)
        logging.info('Resuming training from step %d.', last_checkpoint_step)

  train_dataset = train_dataset_builder.load(batch_size=batch_size)
  train_dataset = strategy.experimental_distribute_dataset(train_dataset)
  train_iterator = iter(train_dataset)

  iterations_per_loop = min(eval_frequency, log_frequency)
  # We can only run `iterations_per_loop` steps at a time, because we cannot
  # checkpoint the model inside a tf.function.
  train_step_fn = _train_step_fn(model, optimizer, strategy, metrics,
                                 iterations_per_loop=iterations_per_loop)

  if trial_dir:
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(trial_dir, 'train'))
  else:
    train_summary_writer = None

  val_summary_writer = None
  test_summary_writer = None
  if mode == 'train_and_eval':
    (val_fn, val_dataset, val_summary_writer, test_fn, test_dataset,
     test_summary_writer) = eval_lib.setup_eval(
         validation_dataset_builder=validation_dataset_builder,
         test_dataset_builder=test_dataset_builder,
         batch_size=eval_batch_size,
         strategy=strategy,
         trial_dir=trial_dir,
         model=model,
         metrics=metrics)

  # Each call to train_step_fn will run iterations_per_loop steps.
  num_train_fn_steps = train_steps // iterations_per_loop
  # We are guaranteed that `last_checkpoint_step` will be divisible by
  # `iterations_per_loop` because that is how frequently we checkpoint.
  start_train_fn_step = last_checkpoint_step // iterations_per_loop
  for train_fn_step in range(start_train_fn_step, num_train_fn_steps):
    current_step = train_fn_step * iterations_per_loop
    # Checkpoint at the start of the step, before the training op is run.
    if (checkpoint_manager and current_step % eval_frequency == 0 and
        current_step != last_checkpoint_step):
      checkpoint_manager.save(checkpoint_number=current_step)
    if mode == 'train_and_eval' and current_step % eval_frequency == 0:
      eval_lib.run_eval_epoch(
          val_fn,
          val_dataset,
          val_summary_writer,
          test_fn,
          test_dataset,
          test_summary_writer,
          current_step,
          hparams=None)  # Only write hparams on the last step.
    train_step_outputs = train_step_fn(train_iterator)
    if current_step % log_frequency == 0:
      _write_summaries(train_step_outputs, current_step, train_summary_writer)
      train_step_outputs_np = {
          k: v.numpy() for k, v in train_step_outputs.items()
      }
      logging.info('Training metrics for step %d: %s', current_step,
                   train_step_outputs_np)

  if train_steps % iterations_per_loop != 0:
    remainder_train_step_fn = _train_step_fn(
        model,
        optimizer,
        strategy,
        metrics,
        iterations_per_loop=train_steps % iterations_per_loop)
    train_step_outputs = remainder_train_step_fn(train_iterator)

  # Always evaluate and record metrics at the end of training.
  _write_summaries(train_step_outputs, train_steps, train_summary_writer,
                   hparams)
  train_step_outputs_np = {
      k: v.numpy() for k, v in train_step_outputs.items()
  }
  logging.info('Training metrics for step %d: %s', train_steps,
               train_step_outputs_np)
  if mode == 'train_and_eval':
    eval_lib.run_eval_epoch(
        val_fn,
        val_dataset,
        val_summary_writer,
        test_fn,
        test_dataset,
        test_summary_writer,
        train_steps,
        hparams=hparams)
  # Save checkpoint at the end of training.
  if checkpoint_manager:
    checkpoint_manager.save(checkpoint_number=train_steps)
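

# --- Illustrative sketch (not part of this module) ---------------------------
# `_train_step_fn` is defined elsewhere in this library; the sketch below only
# illustrates the contract `run_train_loop` relies on: a factory returning a
# tf.function that advances the distributed `train_iterator` by
# `iterations_per_loop` steps per call and returns a dict of scalar tensors,
# so checkpointing and summary writing can happen between calls, outside the
# tf.function. The loss computation and the 'features'/'labels'/'train/loss'
# keys are assumptions made for illustration only.
def _train_step_fn_sketch(model, optimizer, strategy, metrics,
                          iterations_per_loop):
  """Sketch of the multi-step train function that `run_train_loop` expects."""
  del metrics  # Metric updates are omitted from this sketch for brevity.

  @tf.function
  def train_steps(train_iterator):
    def step_fn(inputs):
      # Assumes each dataset element is a dict with 'features' and 'labels';
      # the real dataset builders may use different keys.
      with tf.GradientTape() as tape:
        logits = model(inputs['features'], training=True)
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(
                inputs['labels'], logits, from_logits=True))
      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      return loss

    # Run `iterations_per_loop` steps inside one tf.function call so that
    # checkpointing and summary writing stay outside the compiled graph.
    for _ in tf.range(iterations_per_loop - 1):
      strategy.run(step_fn, args=(next(train_iterator),))
    # As a simplification, report only the loss of the final step in the loop.
    per_replica_loss = strategy.run(step_fn, args=(next(train_iterator),))
    loss = strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)
    # `run_train_loop` calls `.numpy()` on each value, so return scalar tensors.
    return {'train/loss': loss}

  return train_steps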