Beispiel #1
0
def train_step(config, model, optimizer, metrics, ema=None, strategy=None):
    """Builds the training step and wraps it for the given strategy.

    Args:
        config: Configuration object forwarded to ``loss_on_batch``.
        model: Model whose ``trainable_variables`` are optimised.
        optimizer: Optimizer used to apply the computed gradients.
        metrics: Mapping from metric key to a metric object; each metric is
            updated with ``extra['scalar'][key]`` from the loss computation.
        ema: Optional object with an ``apply`` method that is called on the
            trainable variables after every optimizer update.
        strategy: Optional distribution strategy. When truthy, the loss is
            divided by ``strategy.num_replicas_in_sync`` before the gradients
            are taken.

    Returns:
        Whatever ``train_utils.step_with_strategy`` returns for the
        per-replica step function and ``strategy``.
    """

    def step_fn(inputs):
        """Runs one optimisation step on a replica; returns unscaled loss."""
        with tf.GradientTape() as tape:
            loss, extra = loss_on_batch(inputs, model, config, training=True)
            # The scaling has to happen while the tape is recording so that
            # the gradients are taken with respect to the scaled loss.
            if strategy:
                replica_count = float(strategy.num_replicas_in_sync)
                scaled_loss = loss / replica_count
            else:
                scaled_loss = loss

        # Read the variables only after the forward pass has run.
        variables = model.trainable_variables
        gradients = tape.gradient(scaled_loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))

        for name, tracker in metrics.items():
            tracker.update_state(extra['scalar'][name])

        if ema is not None:
            ema.apply(variables)
        return loss

    return train_utils.step_with_strategy(step_fn, strategy)
Beispiel #2
0
def evaluate(logdir, subset):
    """Executes the evaluation loop.

    Repeatedly waits for a checkpoint, restores the model, runs it over the
    evaluation data while accumulating mean metrics, and writes the metric
    values as summaries. The loop ends once
    ``train_utils.wait_for_checkpoint`` returns ``None``.

    Args:
        logdir: Directory that is watched for checkpoints. Evaluation
            summaries are written to a sub-directory of it.
        subset: Dataset subset name forwarded to ``datasets.get_dataset``; it
            is also embedded in the summary directory name.
    """
    config = FLAGS.config
    strategy, batch_size = train_utils.setup_strategy(config, FLAGS.master,
                                                      FLAGS.devices_per_worker,
                                                      FLAGS.mode,
                                                      FLAGS.accelerator_type)

    def input_fn(_=None):
        # NOTE(review): the dataset is batched with config.eval_batch_size
        # while num_steps below is derived from the batch_size returned by
        # setup_strategy -- confirm that the two are meant to agree.
        return datasets.get_dataset(name=config.dataset,
                                    config=config,
                                    batch_size=config.eval_batch_size,
                                    subset=subset)

    model, optimizer, ema = train_utils.with_strategy(
        lambda: build(config, batch_size, False), strategy)

    metric_keys = ['loss', 'total_loss']
    # metric_keys += model.metric_keys
    metrics = {}
    for metric_key in metric_keys:
        # One running mean per metric key, created through the strategy
        # helper.
        func = functools.partial(tf.keras.metrics.Mean, metric_key)
        metrics[metric_key] = train_utils.with_strategy(func, strategy)

    checkpoints = train_utils.with_strategy(
        lambda: train_utils.create_checkpoint(model, optimizer, ema), strategy)
    dataset = train_utils.dataset_with_strategy(input_fn, strategy)

    def step_fn(batch):
        """Evaluates one batch and folds its scalars into the metrics."""
        _, extra = loss_on_batch(batch, model, config, training=False)

        for metric_key in metric_keys:
            curr_metric = metrics[metric_key]
            curr_scalar = extra['scalar'][metric_key]
            curr_metric.update_state(curr_scalar)

    num_examples = config.eval_num_examples
    eval_step = train_utils.step_with_strategy(step_fn, strategy)
    ckpt_path = None
    wait_max = config.get('eval_checkpoint_wait_secs',
                          config.save_checkpoint_secs * 100)
    is_ema = bool(ema)

    eval_summary_dir = os.path.join(
        logdir, 'eval_{}_summaries_pyk_{}'.format(subset, is_ema))
    writer = tf.summary.create_file_writer(eval_summary_dir)

    while True:
        ckpt_path = train_utils.wait_for_checkpoint(logdir, ckpt_path,
                                                    wait_max)
        logging.info(ckpt_path)
        if ckpt_path is None:
            logging.info('Timed out waiting for checkpoint.')
            break

        # NOTE(review): restore() receives logdir rather than ckpt_path;
        # presumably it loads the most recent checkpoint -- confirm.
        train_utils.with_strategy(
            lambda: train_utils.restore(model, checkpoints, logdir, ema),
            strategy)
        data_iterator = iter(dataset)
        num_steps = num_examples // batch_size

        # Start every checkpoint evaluation from fresh running means.
        for metric in metrics.values():
            metric.reset_states()

        logging.info('Starting evaluation.')
        done = False
        for i in range(0, num_steps, FLAGS.steps_per_summaries):
            start_run = time.time()
            # The last chunk may be shorter than FLAGS.steps_per_summaries.
            steps_in_chunk = min(num_steps - i, FLAGS.steps_per_summaries)
            for k in range(steps_in_chunk):
                try:
                    if k % 10 == 0:
                        logging.info('Step: %d', (i + k + 1))
                    eval_step(data_iterator)
                except (StopIteration, tf.errors.OutOfRangeError):
                    # The dataset ran out before num_steps was reached.
                    done = True
                    break
            if done:
                break
            bits_per_dim = metrics['loss'].result()
            # Divide by the steps actually executed in this chunk so that
            # the speed of a short final chunk is not under-reported.
            logging.info(
                'Bits/Dim: %.3f, Speed: %.3f seconds/step, Step: %d/%d',
                bits_per_dim,
                (time.time() - start_run) / steps_in_chunk,
                i + steps_in_chunk, num_steps)

        # logging.info('Final Bits/Dim: %.3f', bits_per_dim)
        with writer.as_default():
            for metric_key, metric in metrics.items():
                curr_scalar = metric.result().numpy()
                tf.summary.scalar(metric_key,
                                  curr_scalar,
                                  step=optimizer.iterations)