def restore_checkpoint(model, ema, strategy, latest_ckpt=None, optimizer=None):
  """Builds a checkpoint (optionally tracking the optimizer) and restores it."""
  if optimizer is None:
    ckpt_func = functools.partial(
        train_utils.create_checkpoint, models=model, ema=ema)
  else:
    ckpt_func = functools.partial(
        train_utils.create_checkpoint, models=model, ema=ema,
        optimizer=optimizer)

  checkpoint = train_utils.with_strategy(ckpt_func, strategy)
  if latest_ckpt:
    logging.info('Restoring from pretrained directory: %s', latest_ckpt)
    train_utils.with_strategy(lambda: checkpoint.restore(latest_ckpt), strategy)
  return checkpoint
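

# A sketch of the warm-starting pattern that train() below builds on top of
# restore_checkpoint: the first call restores model/EMA weights only from a
# pretrained directory, the second creates a fresh checkpoint that also tracks
# the optimizer, so fine-tuning restarts the step counter. `pretrain_dir` is a
# placeholder argument for illustration; this helper is not called anywhere.
def _warm_start_sketch(model, ema, optimizer, strategy, pretrain_dir):
  pretrain_ckpt = tf.train.latest_checkpoint(pretrain_dir)
  restore_checkpoint(model, ema, strategy, pretrain_ckpt, optimizer=None)
  return restore_checkpoint(
      model, ema, strategy, latest_ckpt=None, optimizer=optimizer)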


def evaluate(logdir, subset):
  """Executes the evaluation loop."""
  config = FLAGS.config
  strategy, batch_size = train_utils.setup_strategy(
      config, FLAGS.master, FLAGS.devices_per_worker, FLAGS.mode,
      FLAGS.accelerator_type)

  def input_fn(_=None):
    return datasets.get_dataset(
        name=config.dataset,
        config=config,
        batch_size=config.eval_batch_size,
        subset=subset)

  model, optimizer, ema = train_utils.with_strategy(
      lambda: build(config, batch_size, False), strategy)

  metric_keys = ['loss', 'total_loss']
  # metric_keys += model.metric_keys
  metrics = {}
  for metric_key in metric_keys:
    func = functools.partial(tf.keras.metrics.Mean, metric_key)
    curr_metric = train_utils.with_strategy(func, strategy)
    metrics[metric_key] = curr_metric

  checkpoints = train_utils.with_strategy(
      lambda: train_utils.create_checkpoint(model, optimizer, ema), strategy)
  dataset = train_utils.dataset_with_strategy(input_fn, strategy)

  def step_fn(batch):
    _, extra = loss_on_batch(batch, model, config, training=False)
    for metric_key in metric_keys:
      curr_metric = metrics[metric_key]
      curr_scalar = extra['scalar'][metric_key]
      curr_metric.update_state(curr_scalar)

  num_examples = config.eval_num_examples
  eval_step = train_utils.step_with_strategy(step_fn, strategy)
  ckpt_path = None
  wait_max = config.get(
      'eval_checkpoint_wait_secs', config.save_checkpoint_secs * 100)
  is_ema = True if ema else False

  eval_summary_dir = os.path.join(
      logdir, 'eval_{}_summaries_pyk_{}'.format(subset, is_ema))
  writer = tf.summary.create_file_writer(eval_summary_dir)

  # Poll the training directory for new checkpoints and evaluate each one,
  # stopping once no new checkpoint arrives within `wait_max` seconds.
  while True:
    ckpt_path = train_utils.wait_for_checkpoint(logdir, ckpt_path, wait_max)
    logging.info(ckpt_path)
    if ckpt_path is None:
      logging.info('Timed out waiting for checkpoint.')
      break

    train_utils.with_strategy(
        lambda: train_utils.restore(model, checkpoints, logdir, ema), strategy)
    data_iterator = iter(dataset)
    num_steps = num_examples // batch_size

    for metric_key, metric in metrics.items():
      metric.reset_states()

    logging.info('Starting evaluation.')
    done = False
    for i in range(0, num_steps, FLAGS.steps_per_summaries):
      start_run = time.time()
      for k in range(min(num_steps - i, FLAGS.steps_per_summaries)):
        try:
          if k % 10 == 0:
            logging.info('Step: %d', (i + k + 1))
          eval_step(data_iterator)
        except (StopIteration, tf.errors.OutOfRangeError):
          done = True
          break
      if done:
        break
      bits_per_dim = metrics['loss'].result()
      logging.info(
          'Bits/Dim: %.3f, Speed: %.3f seconds/step, Step: %d/%d',
          bits_per_dim, (time.time() - start_run) / FLAGS.steps_per_summaries,
          i + k + 1, num_steps)

    # logging.info('Final Bits/Dim: %.3f', bits_per_dim)
    with writer.as_default():
      for metric_key, metric in metrics.items():
        curr_scalar = metric.result().numpy()
        tf.summary.scalar(metric_key, curr_scalar, step=optimizer.iterations)
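

# `train_utils.wait_for_checkpoint` above blocks until a checkpoint newer than
# `ckpt_path` appears in `logdir` (or returns None on timeout). A minimal
# sketch of such a polling helper, built on tf.train.latest_checkpoint, is
# shown below; it is an assumption about the helper's behaviour, not the
# implementation in train_utils, and is unused in this file.
def _wait_for_checkpoint_sketch(ckpt_dir, last_ckpt, timeout_secs,
                                poll_secs=60):
  """Returns the newest checkpoint path once it changes, or None on timeout."""
  deadline = time.time() + timeout_secs
  while time.time() < deadline:
    latest = tf.train.latest_checkpoint(ckpt_dir)
    if latest is not None and latest != last_ckpt:
      return latest
    time.sleep(poll_secs)
  return None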


def train(logdir):
  """Executes the training loop."""
  config = FLAGS.config
  steps_per_write = FLAGS.steps_per_summaries
  train_utils.write_config(config, logdir)

  strategy, batch_size = train_utils.setup_strategy(
      config, FLAGS.master, FLAGS.devices_per_worker, FLAGS.mode,
      FLAGS.accelerator_type)

  def input_fn(input_context=None):
    read_config = None
    if input_context is not None:
      read_config = tfds.ReadConfig(input_context=input_context)

    dataset = datasets.get_dataset(
        name=FLAGS.dataset,
        config=config,
        batch_size=config.batch_size,
        subset='train',
        read_config=read_config,
        data_dir=FLAGS.data_dir)
    return dataset

  # DATASET CREATION.
  logging.info('Building dataset.')
  train_dataset = train_utils.dataset_with_strategy(input_fn, strategy)
  data_iterator = iter(train_dataset)

  # MODEL BUILDING.
  logging.info('Building model.')
  model, optimizer, ema = train_utils.with_strategy(
      lambda: build(config, batch_size, True), strategy)
  model.summary(120, print_fn=logging.info)

  # METRIC CREATION.
  metrics = {}
  metric_keys = ['loss', 'total_loss']
  metric_keys += model.metric_keys
  for metric_key in metric_keys:
    func = functools.partial(tf.keras.metrics.Mean, metric_key)
    curr_metric = train_utils.with_strategy(func, strategy)
    metrics[metric_key] = curr_metric

  # CHECKPOINTING LOGIC.
  if FLAGS.pretrain_dir is not None:
    pretrain_ckpt = tf.train.latest_checkpoint(FLAGS.pretrain_dir)
    assert pretrain_ckpt

    # Load the entire model without the optimizer from the checkpoints.
    restore_checkpoint(model, ema, strategy, pretrain_ckpt, optimizer=None)
    # New tf.train.Checkpoint instance with a reset optimizer.
    checkpoint = restore_checkpoint(
        model, ema, strategy, latest_ckpt=None, optimizer=optimizer)
  else:
    latest_ckpt = tf.train.latest_checkpoint(logdir)
    checkpoint = restore_checkpoint(
        model, ema, strategy, latest_ckpt, optimizer=optimizer)

  checkpoint = tf.train.CheckpointManager(
      checkpoint, directory=logdir, checkpoint_name='model', max_to_keep=10)
  if optimizer.iterations.numpy() == 0:
    checkpoint_name = checkpoint.save()
    logging.info('Saved checkpoint to %s', checkpoint_name)

  train_summary_dir = os.path.join(logdir, 'train_summaries')
  writer = tf.summary.create_file_writer(train_summary_dir)
  start_time = time.time()

  logging.info('Start Training.')

  # This hack of wrapping up multiple train steps with a tf.function call
  # speeds up training significantly.
  # See: https://www.tensorflow.org/guide/tpu#improving_performance_by_multiple_steps_within_tffunction  # pylint: disable=line-too-long
  @tf.function
  def train_multiple_steps(iterator, steps_per_epoch):
    train_step_f = train_step(config, model, optimizer, metrics, ema, strategy)
    for _ in range(steps_per_epoch):
      train_step_f(iterator)

  while optimizer.iterations.numpy() < config.get('max_train_steps', 1000000):
    num_train_steps = optimizer.iterations

    for metric_key in metric_keys:
      metrics[metric_key].reset_states()

    start_run = time.time()
    train_multiple_steps(data_iterator, tf.convert_to_tensor(steps_per_write))
    steps_per_sec = steps_per_write / (time.time() - start_run)

    with writer.as_default():
      for metric_key, metric in metrics.items():
        metric_np = metric.result().numpy()
        tf.summary.scalar(metric_key, metric_np, step=num_train_steps)

        if metric_key == 'total_loss':
          logging.info('Loss: %.3f bits/dim, Speed: %.3f steps/second',
                       metric_np, steps_per_sec)

    if time.time() - start_time > config.save_checkpoint_secs:
      checkpoint_name = checkpoint.save()
      logging.info('Saved checkpoint to %s', checkpoint_name)
      start_time = time.time()
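

# The `train_multiple_steps` wrapper above follows the linked TF guide: running
# several steps inside one tf.function call amortises the per-step Python and
# launch overhead, which matters most on TPUs. The self-contained toy sketch
# below illustrates the same pattern with a single variable and a made-up
# quadratic loss; it is purely illustrative and not used by train().
def _multi_step_tf_function_sketch():
  """Toy example: many gradient updates per tf.function invocation."""
  w = tf.Variable(5.0)

  @tf.function
  def run_steps(steps, lr=0.1):
    # AutoGraph compiles this loop into the graph, so a single Python call
    # executes `steps` updates back to back.
    for _ in tf.range(steps):
      with tf.GradientTape() as tape:
        loss = tf.square(w - 3.0)
      w.assign_sub(lr * tape.gradient(loss, w))
    return w.read_value()

  return run_steps(tf.constant(200))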