Example #1
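All three snippets come from the same training script and assume a common set of imports. The standard ones are reproduced below as an assumption inferred from usage (logging is taken to be absl's, which pairs with the FLAGS usage); the project-local modules train_utils and datasets, plus helpers such as build, loss_on_batch, and train_step, are defined elsewhere in the same package and are not shown:

import functools
import os
import time

from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds

FLAGS = flags.FLAGS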
def restore_checkpoint(model, ema, strategy, latest_ckpt=None, optimizer=None):
    """Builds a checkpoint (optionally tracking the optimizer) and restores it."""
    # Without an optimizer, only the model and EMA weights are tracked, which
    # is what warm-starting from a pretrained directory needs.
    if optimizer is None:
        ckpt_func = functools.partial(train_utils.create_checkpoint,
                                      models=model,
                                      ema=ema)
    else:
        ckpt_func = functools.partial(train_utils.create_checkpoint,
                                      models=model,
                                      ema=ema,
                                      optimizer=optimizer)

    checkpoint = train_utils.with_strategy(ckpt_func, strategy)
    if latest_ckpt:
        logging.info('Restoring from pretrained directory: %s', latest_ckpt)
        train_utils.with_strategy(lambda: checkpoint.restore(latest_ckpt),
                                  strategy)
    return checkpoint
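The helper train_utils.with_strategy is used throughout but never shown. A minimal sketch of what it plausibly does, assuming it only needs to run a zero-argument callable inside the distribution strategy's scope; this is inferred from the call sites above, not taken from the source:

def with_strategy(fn, strategy):
    # Hypothetical sketch: enter the strategy scope so that any tf.Variables
    # the callable creates (model weights, optimizer slots, metric state)
    # become distributed variables.
    with strategy.scope():
        return fn()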
Example #2
def evaluate(logdir, subset):
    """Executes the evaluation loop."""
    config = FLAGS.config
    strategy, batch_size = train_utils.setup_strategy(config, FLAGS.master,
                                                      FLAGS.devices_per_worker,
                                                      FLAGS.mode,
                                                      FLAGS.accelerator_type)

    def input_fn(_=None):
        return datasets.get_dataset(name=config.dataset,
                                    config=config,
                                    batch_size=config.eval_batch_size,
                                    subset=subset)

    model, optimizer, ema = train_utils.with_strategy(
        lambda: build(config, batch_size, False), strategy)

    metric_keys = ['loss', 'total_loss']
    # metric_keys += model.metric_keys
    metrics = {}
    for metric_key in metric_keys:
        func = functools.partial(tf.keras.metrics.Mean, metric_key)
        curr_metric = train_utils.with_strategy(func, strategy)
        metrics[metric_key] = curr_metric

    checkpoints = train_utils.with_strategy(
        lambda: train_utils.create_checkpoint(model, optimizer, ema), strategy)
    dataset = train_utils.dataset_with_strategy(input_fn, strategy)

    def step_fn(batch):
        # Per-replica eval step: compute the loss and fold its scalar outputs
        # into the running metrics.
        _, extra = loss_on_batch(batch, model, config, training=False)

        for metric_key in metric_keys:
            curr_metric = metrics[metric_key]
            curr_scalar = extra['scalar'][metric_key]
            curr_metric.update_state(curr_scalar)

    num_examples = config.eval_num_examples
    eval_step = train_utils.step_with_strategy(step_fn, strategy)
    ckpt_path = None
    wait_max = config.get('eval_checkpoint_wait_secs',
                          config.save_checkpoint_secs * 100)
    is_ema = bool(ema)

    eval_summary_dir = os.path.join(
        logdir, 'eval_{}_summaries_pyk_{}'.format(subset, is_ema))
    writer = tf.summary.create_file_writer(eval_summary_dir)

    # Evaluate each new checkpoint; stop once waiting for one times out.
    while True:
        ckpt_path = train_utils.wait_for_checkpoint(logdir, ckpt_path,
                                                    wait_max)
        logging.info(ckpt_path)
        if ckpt_path is None:
            logging.info('Timed out waiting for checkpoint.')
            break

        train_utils.with_strategy(
            lambda: train_utils.restore(model, checkpoints, logdir, ema),
            strategy)
        data_iterator = iter(dataset)
        num_steps = num_examples // batch_size

        for metric in metrics.values():
            metric.reset_states()

        logging.info('Starting evaluation.')
        done = False
        for i in range(0, num_steps, FLAGS.steps_per_summaries):
            start_run = time.time()
            for k in range(min(num_steps - i, FLAGS.steps_per_summaries)):
                try:
                    if k % 10 == 0:
                        logging.info('Step: %d', (i + k + 1))
                    eval_step(data_iterator)
                except (StopIteration, tf.errors.OutOfRangeError):
                    done = True
                    break
            if done:
                break
            bits_per_dim = metrics['loss'].result()
            logging.info(
                'Bits/Dim: %.3f, Speed: %.3f seconds/step, Step: %d/%d',
                bits_per_dim,
                (time.time() - start_run) / FLAGS.steps_per_summaries,
                i + k + 1, num_steps)

        # logging.info('Final Bits/Dim: %.3f', bits_per_dim)
        with writer.as_default():
            for metric_key, metric in metrics.items():
                curr_scalar = metric.result().numpy()
                tf.summary.scalar(metric_key,
                                  curr_scalar,
                                  step=optimizer.iterations)
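Similarly, train_utils.step_with_strategy is only called, never defined, above. A plausible sketch, assuming it wraps the per-replica step function in a tf.function that pulls one batch from the iterator and dispatches it with strategy.run; again an inference from the call sites, not the source's actual helper:

def step_with_strategy(step_fn, strategy):
    @tf.function
    def step(iterator):
        # Hypothetical sketch: fetch one batch and run step_fn on every
        # replica. Iterator exhaustion surfaces as the StopIteration /
        # OutOfRangeError that the eval loop above catches.
        strategy.run(step_fn, args=(next(iterator),))
    return step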
Example #3
def train(logdir):
    """Executes the training loop."""
    config = FLAGS.config
    steps_per_write = FLAGS.steps_per_summaries
    train_utils.write_config(config, logdir)

    strategy, batch_size = train_utils.setup_strategy(config, FLAGS.master,
                                                      FLAGS.devices_per_worker,
                                                      FLAGS.mode,
                                                      FLAGS.accelerator_type)

    def input_fn(input_context=None):
        read_config = None
        if input_context is not None:
            read_config = tfds.ReadConfig(input_context=input_context)

        dataset = datasets.get_dataset(name=FLAGS.dataset,
                                       config=config,
                                       batch_size=config.batch_size,
                                       subset='train',
                                       read_config=read_config,
                                       data_dir=FLAGS.data_dir)
        return dataset

    # DATASET CREATION.
    logging.info('Building dataset.')
    train_dataset = train_utils.dataset_with_strategy(input_fn, strategy)
    data_iterator = iter(train_dataset)

    # MODEL BUILDING.
    logging.info('Building model.')
    model, optimizer, ema = train_utils.with_strategy(
        lambda: build(config, batch_size, True), strategy)
    model.summary(120, print_fn=logging.info)

    # METRIC CREATION.
    metrics = {}
    metric_keys = ['loss', 'total_loss']
    metric_keys += model.metric_keys
    for metric_key in metric_keys:
        func = functools.partial(tf.keras.metrics.Mean, metric_key)
        curr_metric = train_utils.with_strategy(func, strategy)
        metrics[metric_key] = curr_metric

    # CHECKPOINTING LOGIC.
    if FLAGS.pretrain_dir is not None:
        pretrain_ckpt = tf.train.latest_checkpoint(FLAGS.pretrain_dir)
        assert pretrain_ckpt

        # Restore the model weights (without optimizer state) from the
        # pretrained checkpoint.
        restore_checkpoint(model, ema, strategy, pretrain_ckpt, optimizer=None)
        # New tf.train.Checkpoint instance with a reset optimizer.
        checkpoint = restore_checkpoint(model,
                                        ema,
                                        strategy,
                                        latest_ckpt=None,
                                        optimizer=optimizer)
    else:
        latest_ckpt = tf.train.latest_checkpoint(logdir)
        checkpoint = restore_checkpoint(model,
                                        ema,
                                        strategy,
                                        latest_ckpt,
                                        optimizer=optimizer)

    # Wrap the checkpoint in a manager that rotates saved files under logdir.
    checkpoint = tf.train.CheckpointManager(checkpoint,
                                            directory=logdir,
                                            checkpoint_name='model',
                                            max_to_keep=10)
    # Write an initial step-0 checkpoint before any training happens.
    if optimizer.iterations.numpy() == 0:
        checkpoint_name = checkpoint.save()
        logging.info('Saved checkpoint to %s', checkpoint_name)

    train_summary_dir = os.path.join(logdir, 'train_summaries')
    writer = tf.summary.create_file_writer(train_summary_dir)
    start_time = time.time()

    logging.info('Start Training.')

    # Wrapping multiple train steps inside a single tf.function call speeds up
    # training significantly.
    # See: https://www.tensorflow.org/guide/tpu#improving_performance_by_multiple_steps_within_tffunction # pylint: disable=line-too-long
    @tf.function
    def train_multiple_steps(iterator, steps_per_epoch):

        train_step_f = train_step(config, model, optimizer, metrics, ema,
                                  strategy)

        for _ in range(steps_per_epoch):
            train_step_f(iterator)

    while optimizer.iterations.numpy() < config.get('max_train_steps',
                                                    1000000):
        num_train_steps = optimizer.iterations

        for metric_key in metric_keys:
            metrics[metric_key].reset_states()

        start_run = time.time()

        train_multiple_steps(data_iterator,
                             tf.convert_to_tensor(steps_per_write))

        steps_per_sec = steps_per_write / (time.time() - start_run)
        with writer.as_default():
            for metric_key, metric in metrics.items():
                metric_np = metric.result().numpy()
                tf.summary.scalar(metric_key, metric_np, step=num_train_steps)

                if metric_key == 'total_loss':
                    logging.info(
                        'Loss: %.3f bits/dim, Speed: %.3f steps/second',
                        metric_np, steps_per_sec)

        if time.time() - start_time > config.save_checkpoint_secs:
            checkpoint_name = checkpoint.save()
            logging.info('Saved checkpoint to %s', checkpoint_name)
            start_time = time.time()
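For completeness, a hypothetical absl entry point that dispatches between the two loops. The logdir flag and the 'valid' subset are illustrative assumptions; only the mode flag is actually referenced in the snippets above:

from absl import app

def main(_):
    # Hypothetical driver: dispatch on the same mode flag that is passed to
    # setup_strategy above.
    if FLAGS.mode == 'train':
        train(FLAGS.logdir)
    else:
        evaluate(FLAGS.logdir, subset='valid')

if __name__ == '__main__':
    app.run(main)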