def setup_eval(
    dataset_builder: ub.datasets.BaseDataset,
    strategy,
    trial_dir: Optional[str],
    model: tf.keras.Model,
    metrics: Dict[str, tf.keras.metrics.Metric]) -> _EvalSetupResult:
  """Sets up test and optional validation loggers, step fns and datasets."""
  test_dataset = dataset_builder.build('test')
  test_dataset = strategy.experimental_distribute_dataset(test_dataset)
  if trial_dir:
    test_summary_writer = tf.summary.create_file_writer(
        os.path.join(trial_dir, 'test'))
  else:
    test_summary_writer = None
  num_test_steps = (
      dataset_builder.info['num_test_examples'] //
      dataset_builder.eval_batch_size)
  test_fn = eval_step_fn(
      model,
      strategy,
      metrics,
      iterations_per_loop=num_test_steps)

  # We need separate val_fn and test_fn because tf.function would otherwise
  # retrace on every call, which is very slow, since we pass in a Python dict
  # of metrics and a Python int for iterations_per_loop.
  val_fn = None
  val_dataset = None
  val_summary_writer = None
  if dataset_builder.info['num_validation_examples'] > 0:
    num_val_steps = (
        dataset_builder.info['num_validation_examples'] //
        dataset_builder.eval_batch_size)
    val_dataset = dataset_builder.build('validation')
    val_dataset = strategy.experimental_distribute_dataset(val_dataset)
    if trial_dir:
      val_summary_writer = tf.summary.create_file_writer(
          os.path.join(trial_dir, 'validation'))
    if num_val_steps == num_test_steps:
      val_fn = test_fn
    else:
      # The metrics are reset at the start of each call to {val,test}_fn, so
      # reusing them is safe.
      val_fn = eval_step_fn(
          model,
          strategy,
          metrics,
          iterations_per_loop=num_val_steps)
  return (
      val_fn,
      val_dataset,
      val_summary_writer,
      test_fn,
      test_dataset,
      test_summary_writer)
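

# Illustrative sketch only (not part of the original module): a minimal
# `eval_step_fn` consistent with how it is called above, i.e. taking
# (model, strategy, metrics, iterations_per_loop) and returning a callable
# that consumes a distributed dataset iterator. The feature/label keys and
# the metric update signature are assumptions for this example; it assumes
# the module-level `import tensorflow as tf`.
def _example_eval_step_fn(model, strategy, metrics, iterations_per_loop):
  """Returns a tf.function that runs `iterations_per_loop` eval steps."""

  @tf.function
  def eval_fn(dataset_iterator):
    # Reset metrics so each call reports results for this evaluation only.
    for metric in metrics.values():
      metric.reset_states()

    def step_fn(inputs):
      features = inputs['features']  # assumed feature/label keys
      labels = inputs['labels']
      logits = model(features, training=False)
      for metric in metrics.values():
        metric.update_state(labels, logits)

    # Drive multiple distributed steps inside one traced function.
    for _ in tf.range(iterations_per_loop):
      strategy.run(step_fn, args=(next(dataset_iterator),))
    return {name: metric.result() for name, metric in metrics.items()}

  return eval_fn

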
def setup_eval(
    validation_dataset_builder: Optional[ub.datasets.BaseDataset],
    test_dataset_builder: ub.datasets.BaseDataset,
    batch_size: int,
    strategy,
    trial_dir: str,
    model: tf.keras.Model,
    loss_fn,
    metric_names: List[str]) -> _EvalSetupResult:
  """Sets up test and optional validation loggers, step fns and datasets."""
  test_dataset = test_dataset_builder.load(batch_size=batch_size)
  test_dataset = strategy.experimental_distribute_dataset(test_dataset)
  test_summary_writer = tf.summary.create_file_writer(
      os.path.join(trial_dir, 'summaries/test'))
  num_test_steps = test_dataset_builder.num_examples // batch_size
  test_metrics = {
      name: tf.keras.metrics.Mean(name, dtype=tf.float32)
      for name in metric_names
  }
  test_fn = eval_step_fn(
      model,
      loss_fn,
      strategy,
      test_metrics,
      iterations_per_loop=num_test_steps)

  # We need separate val_fn and test_fn because tf.function would otherwise
  # retrace on every call, which is very slow, since we pass in a Python dict
  # of metrics and a Python int for iterations_per_loop.
  val_fn = None
  val_dataset = None
  val_summary_writer = None
  if validation_dataset_builder:
    num_val_steps = validation_dataset_builder.num_examples // batch_size
    val_dataset = validation_dataset_builder.load(batch_size=batch_size)
    val_dataset = strategy.experimental_distribute_dataset(val_dataset)
    val_summary_writer = tf.summary.create_file_writer(
        os.path.join(trial_dir, 'summaries/val'))
    if num_val_steps == num_test_steps:
      val_fn = test_fn
    else:
      # The metrics are reset at the start of each call to {val,test}_fn, so
      # reusing them is safe.
      val_fn = eval_step_fn(
          model,
          loss_fn,
          strategy,
          test_metrics,
          iterations_per_loop=num_val_steps)
  return (
      val_fn,
      val_dataset,
      val_summary_writer,
      test_fn,
      test_dataset,
      test_summary_writer)
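

# Illustrative only: a `_EvalSetupResult` alias and step-fn signature that are
# consistent with the 6-tuple returned by the variant above. The real alias is
# defined elsewhere in the module; the exact element types are assumptions.
# Assumes `from typing import Any, Callable, Dict, Iterator, Optional, Tuple`
# at module top.
_ExampleEvalStepFn = Callable[[Iterator[Any]], Dict[str, tf.Tensor]]
_ExampleEvalSetupResult = Tuple[
    Optional[_ExampleEvalStepFn],        # val_fn
    Optional[Any],                       # val_dataset (distributed dataset)
    Optional[tf.summary.SummaryWriter],  # val_summary_writer
    _ExampleEvalStepFn,                  # test_fn
    Any,                                 # test_dataset (distributed dataset)
    tf.summary.SummaryWriter]            # test_summary_writer

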
def setup_eval(
    validation_dataset_builder: Optional[ub.datasets.BaseDataset],
    test_dataset_builder: ub.datasets.BaseDataset,
    batch_size: int,
    strategy,
    trial_dir: str,
    model: tf.keras.Model,
    metrics: Dict[str, tf.keras.metrics.Metric],
    ood_dataset_builder: Optional[ub.datasets.BaseDataset] = None,
    ood_metrics: Optional[Dict[str, tf.keras.metrics.Metric]] = None,
    mean_field_factor: float = -1) -> _EvalSetupResult:
  """Sets up test and optional validation loggers, step fns and datasets."""
  test_dataset = test_dataset_builder.load(batch_size=batch_size)
  test_dataset = strategy.experimental_distribute_dataset(test_dataset)
  test_summary_writer = tf.summary.create_file_writer(
      os.path.join(trial_dir, 'test'))
  num_test_steps = test_dataset_builder.num_examples // batch_size
  test_fn = eval_step_fn(
      model,
      strategy,
      metrics,
      iterations_per_loop=num_test_steps,
      mean_field_factor=mean_field_factor)

  ood_fn = None
  ood_dataset = None
  ood_summary_writer = None
  if ((ood_dataset_builder and not ood_metrics) or
      (not ood_dataset_builder and ood_metrics)):
    raise ValueError('Both ood_dataset_builder and ood_metrics must be '
                     'specified.')
  if ood_dataset_builder:
    ood_dataset = ood_dataset_builder.load(batch_size=batch_size)
    ood_dataset = strategy.experimental_distribute_dataset(ood_dataset)
    ood_summary_writer = tf.summary.create_file_writer(
        os.path.join(trial_dir, 'ood'))
    num_ood_steps = ood_dataset_builder.num_examples // batch_size
    ood_fn = eval_step_fn(
        model,
        strategy,
        ood_metrics,
        iterations_per_loop=num_ood_steps,
        label_key='is_in_distribution',
        mean_field_factor=mean_field_factor)

  # We need separate val_fn and test_fn because tf.function would otherwise
  # retrace on every call, which is very slow, since we pass in a Python dict
  # of metrics and a Python int for iterations_per_loop.
  val_fn = None
  val_dataset = None
  val_summary_writer = None
  if validation_dataset_builder:
    num_val_steps = validation_dataset_builder.num_examples // batch_size
    val_dataset = validation_dataset_builder.load(batch_size=batch_size)
    val_dataset = strategy.experimental_distribute_dataset(val_dataset)
    val_summary_writer = tf.summary.create_file_writer(
        os.path.join(trial_dir, 'validation'))
    if num_val_steps == num_test_steps:
      val_fn = test_fn
    else:
      # The metrics are reset at the start of each call to {val,test}_fn, so
      # reusing them is safe.
      val_fn = eval_step_fn(
          model,
          strategy,
          metrics,
          iterations_per_loop=num_val_steps,
          mean_field_factor=mean_field_factor)
  return (test_fn, test_dataset, test_summary_writer, val_fn, val_dataset,
          val_summary_writer, ood_fn, ood_dataset, ood_summary_writer)
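

# Illustrative usage sketch for the OOD-aware `setup_eval` variant above,
# showing the order of the returned 9-tuple. The dataset builders, metric
# choices, batch size, and trial directory are assumptions for this example.
def _example_eval_setup(test_builder, ood_builder, strategy, model, trial_dir):
  metrics = {'accuracy': tf.keras.metrics.SparseCategoricalAccuracy()}
  ood_metrics = {'auroc': tf.keras.metrics.AUC(curve='ROC')}
  (test_fn, test_dataset, _,
   val_fn, val_dataset, _,
   ood_fn, ood_dataset, _) = setup_eval(
       validation_dataset_builder=None,
       test_dataset_builder=test_builder,
       batch_size=128,
       strategy=strategy,
       trial_dir=trial_dir,
       model=model,
       metrics=metrics,
       ood_dataset_builder=ood_builder,
       ood_metrics=ood_metrics)
  # `val_fn`/`val_dataset` are None because no validation builder was given.
  return test_fn, test_dataset, ood_fn, ood_dataset

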
def run_train_loop(
    train_dataset_builder: ub.datasets.BaseDataset,
    validation_dataset_builder: Optional[ub.datasets.BaseDataset],
    test_dataset_builder: ub.datasets.BaseDataset,
    batch_size: int,
    eval_batch_size: int,
    model: tf.keras.Model,
    optimizer: tf.keras.optimizers.Optimizer,
    eval_frequency: int,
    log_frequency: int,
    trial_dir: Optional[str],
    train_steps: int,
    mode: str,
    strategy: tf.distribute.Strategy,
    metrics: Dict[str, Union[tf.keras.metrics.Metric, rm.metrics.KerasMetric]],
    hparams: Dict[str, Any]):
  """Train, possibly evaluate the model, and record metrics."""
  checkpoint_manager = None
  last_checkpoint_step = 0
  if trial_dir:
    # TODO(znado): add train_iterator to this once DistributedIterators are
    # checkpointable.
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, trial_dir, max_to_keep=None)
    checkpoint_path = tf.train.latest_checkpoint(trial_dir)
    if checkpoint_path:
      last_checkpoint_step = int(checkpoint_path.split('-')[-1])
      if last_checkpoint_step >= train_steps:
        # If we have already finished training, exit.
        logging.info('Training has already finished at step %d. Exiting.',
                     train_steps)
        return
      elif last_checkpoint_step > 0:
        # Restore from where we previously finished.
        checkpoint.restore(checkpoint_manager.latest_checkpoint)
        logging.info('Resuming training from step %d.', last_checkpoint_step)

  train_dataset = train_dataset_builder.load(batch_size=batch_size)
  train_dataset = strategy.experimental_distribute_dataset(train_dataset)
  train_iterator = iter(train_dataset)

  iterations_per_loop = min(eval_frequency, log_frequency)
  # We can only run `iterations_per_loop` steps at a time, because we cannot
  # checkpoint the model inside a tf.function.
  train_step_fn = _train_step_fn(
      model,
      optimizer,
      strategy,
      metrics,
      iterations_per_loop=iterations_per_loop)
  if trial_dir:
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(trial_dir, 'train'))
  else:
    train_summary_writer = None

  val_summary_writer = None
  test_summary_writer = None
  if mode == 'train_and_eval':
    (val_fn, val_dataset, val_summary_writer, test_fn, test_dataset,
     test_summary_writer) = eval_lib.setup_eval(
         validation_dataset_builder=validation_dataset_builder,
         test_dataset_builder=test_dataset_builder,
         batch_size=eval_batch_size,
         strategy=strategy,
         trial_dir=trial_dir,
         model=model,
         metrics=metrics)

  # Each call to train_step_fn will run iterations_per_loop steps.
  num_train_fn_steps = train_steps // iterations_per_loop
  # We are guaranteed that `last_checkpoint_step` will be divisible by
  # `iterations_per_loop` because that is how frequently we checkpoint.
  start_train_fn_step = last_checkpoint_step // iterations_per_loop
  for train_fn_step in range(start_train_fn_step, num_train_fn_steps):
    current_step = train_fn_step * iterations_per_loop
    # Checkpoint at the start of the step, before the training op is run.
    if (checkpoint_manager and current_step % eval_frequency == 0 and
        current_step != last_checkpoint_step):
      checkpoint_manager.save(checkpoint_number=current_step)
    if mode == 'train_and_eval' and current_step % eval_frequency == 0:
      eval_lib.run_eval_epoch(
          val_fn,
          val_dataset,
          val_summary_writer,
          test_fn,
          test_dataset,
          test_summary_writer,
          current_step,
          hparams=None)  # Only write hparams on the last step.
    train_step_outputs = train_step_fn(train_iterator)
    if current_step % log_frequency == 0:
      _write_summaries(train_step_outputs, current_step, train_summary_writer)
      train_step_outputs_np = {
          k: v.numpy() for k, v in train_step_outputs.items()
      }
      logging.info('Training metrics for step %d: %s', current_step,
                   train_step_outputs_np)

  if train_steps % iterations_per_loop != 0:
    remainder_train_step_fn = _train_step_fn(
        model,
        optimizer,
        strategy,
        metrics,
        iterations_per_loop=train_steps % iterations_per_loop)
    train_step_outputs = remainder_train_step_fn(train_iterator)

  # Always evaluate and record metrics at the end of training.
  _write_summaries(train_step_outputs, train_steps, train_summary_writer,
                   hparams)
  train_step_outputs_np = {
      k: v.numpy() for k, v in train_step_outputs.items()
  }
  logging.info('Training metrics for step %d: %s', train_steps,
               train_step_outputs_np)
  if mode == 'train_and_eval':
    eval_lib.run_eval_epoch(
        val_fn,
        val_dataset,
        val_summary_writer,
        test_fn,
        test_dataset,
        test_summary_writer,
        train_steps,
        hparams=hparams)
  # Save checkpoint at the end of training.
  if checkpoint_manager:
    checkpoint_manager.save(checkpoint_number=train_steps)
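

# Illustrative sketch only: a `run_eval_epoch` consistent with the call sites
# in `run_train_loop` above (val/test fns, datasets, writers, the current step,
# and optional hparams written on the final call). The real implementation
# lives in eval_lib; the summary-writing details here are assumptions.
def _example_run_eval_epoch(val_fn, val_dataset, val_summary_writer,
                            test_fn, test_dataset, test_summary_writer,
                            current_step, hparams=None):
  """Runs one validation pass (if available) and one test pass."""
  if val_fn is not None and val_dataset is not None:
    val_outputs = val_fn(iter(val_dataset))
    _example_write_eval_summaries(val_outputs, current_step,
                                  val_summary_writer, hparams)
  test_outputs = test_fn(iter(test_dataset))
  _example_write_eval_summaries(test_outputs, current_step,
                                test_summary_writer, hparams)


def _example_write_eval_summaries(outputs, step, summary_writer, hparams=None):
  # Mirrors the pattern of `_write_summaries` used for training metrics above.
  if summary_writer is None:
    return
  with summary_writer.as_default():
    if hparams:
      # Assumes `from tensorboard.plugins.hparams import api as hp`.
      hp.hparams(hparams)
    for name, value in outputs.items():
      tf.summary.scalar(name, value, step=step)

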
def setup_eval(
    dataset_builder: ub.datasets.BaseDataset,
    strategy,
    trial_dir: str,
    model: tf.keras.Model,
    metrics: Dict[str, tf.keras.metrics.Metric],
    ood_dataset_builder: Optional[ub.datasets.BaseDataset] = None,
    ood_metrics: Optional[Dict[str, tf.keras.metrics.Metric]] = None
) -> _EvalSetupResult:
  """Sets up test and optional validation loggers, step fns and datasets."""
  test_dataset = dataset_builder.build('test')
  test_dataset = strategy.experimental_distribute_dataset(test_dataset)
  test_summary_writer = tf.summary.create_file_writer(
      os.path.join(trial_dir, 'test'))
  num_test_steps = (dataset_builder.info['num_test_examples'] //
                    dataset_builder.eval_batch_size)
  test_fn = eval_step_fn(
      model, strategy, metrics, iterations_per_loop=num_test_steps)

  ood_fn = None
  ood_dataset = None
  ood_summary_writer = None
  if ((ood_dataset_builder and not ood_metrics) or
      (not ood_dataset_builder and ood_metrics)):
    raise ValueError('Both ood_dataset_builder and ood_metrics must be '
                     'specified.')
  if ood_dataset_builder:
    ood_dataset_in = dataset_builder.build('test', ood_split='in')
    ood_dataset_out = ood_dataset_builder.build('test', ood_split='ood')
    ood_dataset = ood_dataset_in.concatenate(ood_dataset_out)
    ood_dataset = strategy.experimental_distribute_dataset(ood_dataset)
    ood_summary_writer = tf.summary.create_file_writer(
        os.path.join(trial_dir, 'ood'))
    num_ood_steps = (ood_dataset_builder.info['num_ood_examples'] //
                     ood_dataset_builder.eval_batch_size)
    # The OOD eval runs over the concatenated in-distribution and OOD splits.
    ood_fn = eval_step_fn(
        model,
        strategy,
        ood_metrics,
        iterations_per_loop=num_test_steps + num_ood_steps,
        label_key='is_in_distribution')

  # We need separate val_fn and test_fn because tf.function would otherwise
  # retrace on every call, which is very slow, since we pass in a Python dict
  # of metrics and a Python int for iterations_per_loop.
  val_fn = None
  val_dataset = None
  val_summary_writer = None
  if dataset_builder.info['num_validation_examples'] > 0:
    num_val_steps = (dataset_builder.info['num_validation_examples'] //
                     dataset_builder.eval_batch_size)
    val_dataset = dataset_builder.build('validation')
    val_dataset = strategy.experimental_distribute_dataset(val_dataset)
    val_summary_writer = tf.summary.create_file_writer(
        os.path.join(trial_dir, 'validation'))
    if num_val_steps == num_test_steps:
      val_fn = test_fn
    else:
      # The metrics are reset at the start of each call to {val,test}_fn, so
      # reusing them is safe.
      val_fn = eval_step_fn(
          model, strategy, metrics, iterations_per_loop=num_val_steps)
  return (test_fn, test_dataset, test_summary_writer, val_fn, val_dataset,
          val_summary_writer, ood_fn, ood_dataset, ood_summary_writer)
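

# Illustrative only: how OOD metrics keyed on `label_key='is_in_distribution'`
# (as in the variant above) might be constructed and updated. Using the max
# softmax probability as the in-distribution score, and these metric names,
# are assumptions for this sketch, not the module's API.
def _example_ood_metrics():
  return {
      'ood/auroc': tf.keras.metrics.AUC(curve='ROC', name='ood/auroc'),
      'ood/auprc': tf.keras.metrics.AUC(curve='PR', name='ood/auprc'),
  }


def _example_update_ood_metrics(model, ood_metrics, features,
                                is_in_distribution):
  # `is_in_distribution` is 1 for the in-distribution test split and 0 for the
  # concatenated OOD split.
  probs = tf.nn.softmax(model(features, training=False), axis=-1)
  score = tf.reduce_max(probs, axis=-1)
  for metric in ood_metrics.values():
    metric.update_state(is_in_distribution, score)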