def train_step(config, model, optimizer, metrics, ema=None, strategy=None):
    """Builds the training step callable.

    Args:
      config: Configuration object forwarded to `loss_on_batch`.
      model: Model whose `trainable_variables` are differentiated and updated.
      optimizer: Optimizer used to apply the computed gradients.
      metrics: Mapping from metric key to a metric object; each metric is
        updated with `extra['scalar'][key]` from `loss_on_batch`.
      ema: Optional object; when not None, `ema.apply` is called on the
        model's trainable variables after the optimizer update.
      strategy: Optional distribution strategy; when truthy, the loss is
        divided by `strategy.num_replicas_in_sync` before differentiation.

    Returns:
      The result of `train_utils.step_with_strategy(step_fn, strategy)`.
    """

    def step_fn(inputs):
        """Runs one optimization step on `inputs` and returns the raw loss."""
        with tf.GradientTape() as tape:
            loss, extra = loss_on_batch(inputs, model, config, training=True)
            # The division must happen under the tape so the scaling is part
            # of the differentiated computation.
            if strategy:
                scaled_loss = loss / float(strategy.num_replicas_in_sync)
            else:
                scaled_loss = loss

        variables = model.trainable_variables
        gradients = tape.gradient(scaled_loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))

        for name, tracker in metrics.items():
            tracker.update_state(extra['scalar'][name])

        if ema is not None:
            ema.apply(model.trainable_variables)

        # The unscaled loss is what gets returned to the caller.
        return loss

    return train_utils.step_with_strategy(step_fn, strategy)
def evaluate(logdir, subset):
    """Executes the evaluation loop.

    Repeatedly waits for a checkpoint in `logdir`, restores it, runs up to
    `config.eval_num_examples // batch_size` evaluation steps on `subset`,
    and writes the mean of every tracked metric as a scalar summary. The
    loop stops once `train_utils.wait_for_checkpoint` returns None.

    Args:
      logdir: Directory polled for checkpoints. Summaries are written to the
        sub-directory `eval_<subset>_summaries_pyk_<is_ema>` inside it.
      subset: Dataset subset forwarded to `datasets.get_dataset`; also used
        in the summary directory name.
    """
    config = FLAGS.config
    strategy, batch_size = train_utils.setup_strategy(
        config, FLAGS.master, FLAGS.devices_per_worker, FLAGS.mode,
        FLAGS.accelerator_type)

    def input_fn(_=None):
        return datasets.get_dataset(
            name=config.dataset,
            config=config,
            batch_size=config.eval_batch_size,
            subset=subset)

    model, optimizer, ema = train_utils.with_strategy(
        lambda: build(config, batch_size, False), strategy)

    metric_keys = ['loss', 'total_loss']
    # metric_keys += model.metric_keys
    metrics = {}
    for metric_key in metric_keys:
        func = functools.partial(tf.keras.metrics.Mean, metric_key)
        metrics[metric_key] = train_utils.with_strategy(func, strategy)

    checkpoints = train_utils.with_strategy(
        lambda: train_utils.create_checkpoint(model, optimizer, ema),
        strategy)
    dataset = train_utils.dataset_with_strategy(input_fn, strategy)

    def step_fn(batch):
        """Accumulates the scalar outputs of one batch into the metrics."""
        _, extra = loss_on_batch(batch, model, config, training=False)

        for metric_key in metric_keys:
            metrics[metric_key].update_state(extra['scalar'][metric_key])

    num_examples = config.eval_num_examples
    eval_step = train_utils.step_with_strategy(step_fn, strategy)
    ckpt_path = None
    wait_max = config.get(
        'eval_checkpoint_wait_secs', config.save_checkpoint_secs * 100)
    is_ema = bool(ema)

    eval_summary_dir = os.path.join(
        logdir, 'eval_{}_summaries_pyk_{}'.format(subset, is_ema))
    writer = tf.summary.create_file_writer(eval_summary_dir)

    # Neither operand changes between checkpoints, so compute this once.
    num_steps = num_examples // batch_size

    while True:
        ckpt_path = train_utils.wait_for_checkpoint(
            logdir, ckpt_path, wait_max)
        logging.info(ckpt_path)
        if ckpt_path is None:
            logging.info('Timed out waiting for checkpoint.')
            break

        train_utils.with_strategy(
            lambda: train_utils.restore(model, checkpoints, logdir, ema),
            strategy)
        data_iterator = iter(dataset)

        # Start every checkpoint's evaluation from empty accumulators.
        for metric in metrics.values():
            metric.reset_states()

        logging.info('Starting evaluation.')
        done = False
        for i in range(0, num_steps, FLAGS.steps_per_summaries):
            start_run = time.time()
            # The final window may be shorter than FLAGS.steps_per_summaries.
            window_steps = min(num_steps - i, FLAGS.steps_per_summaries)
            for k in range(window_steps):
                if k % 10 == 0:
                    logging.info('Step: %d', (i + k + 1))
                try:
                    eval_step(data_iterator)
                except (StopIteration, tf.errors.OutOfRangeError):
                    # Dataset exhausted before num_steps were completed.
                    done = True
                    break
            if done:
                break
            bits_per_dim = metrics['loss'].result()
            # Divide by the steps actually run in this window; dividing by
            # FLAGS.steps_per_summaries under-reports the per-step time for
            # a short final window.
            logging.info(
                'Bits/Dim: %.3f, Speed: %.3f seconds/step, Step: %d/%d',
                bits_per_dim,
                (time.time() - start_run) / window_steps,
                i + window_steps, num_steps)

        # logging.info('Final Bits/Dim: %.3f', bits_per_dim)
        with writer.as_default():
            for metric_key, metric in metrics.items():
                curr_scalar = metric.result().numpy()
                tf.summary.scalar(
                    metric_key, curr_scalar, step=optimizer.iterations)