Example #1

from absl import flags
from absl import logging
import tensorflow as tf

import utils  # Project-local module providing average_precision_clevr.

FLAGS = flags.FLAGS

def run_eval(model, data_iterator):
    """Run evaluation."""

    if FLAGS.full_eval:  # Evaluate on the full validation set.
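        # The CLEVR validation split has 15,000 images; any remainder that
        # does not fill a complete batch is dropped by the integer division.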
        num_eval_batches = 15000 // FLAGS.batch_size
    else:
        # By default, we only test on a single batch for faster evaluation.
        num_eval_batches = 1

    outs = None
    for _ in range(num_eval_batches):
        batch = next(data_iterator)
        if outs is None:
            outs = model(batch["image"], training=False)
            target = batch["target"]
        else:
            new_outs = model(batch["image"], training=False)
            outs = tf.concat([outs, new_outs], axis=0)
            target = tf.concat([target, batch["target"]], axis=0)
    logging.info("Finished getting model predictions.")

    # Compute the AP score.
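    # Thresholds are distances in 3D coordinate space; -1. disables the
    # distance check, corresponding to AP at an infinite threshold (AP@inf).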
    ap = [
        utils.average_precision_clevr(outs, target, d)
        for d in [-1., 1., 0.5, 0.25, 0.125]
    ]

    return ap
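
For context, here is a minimal sketch of how run_eval could be driven end to end. It reuses the data_utils and model_utils helpers from Example #2 below; the "validation" split name and the checkpoint-restoring driver are illustrative assumptions, not part of the original script.

def evaluate_checkpoint():
    """Hypothetical driver: restore the latest checkpoint and report AP."""
    resolution = (128, 128)
    data_iterator = data_utils.build_clevr_iterator(
        FLAGS.batch_size, split="validation", resolution=resolution,
        shuffle=False, max_n_objects=10, get_properties=True, apply_crop=False)
    model = model_utils.build_model(resolution, FLAGS.batch_size,
                                    FLAGS.num_slots, FLAGS.num_iterations,
                                    model_type="set_prediction")
    ckpt = tf.train.Checkpoint(network=model)
    ckpt.restore(tf.train.latest_checkpoint(FLAGS.model_dir)).expect_partial()
    ap = run_eval(model, data_iterator)
    logging.info("AP@inf: %.2f, AP@1: %.2f, [email protected]: %.2f, [email protected]: %.2f, "
                 "[email protected]: %.2f", *ap)
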
Example #2

import datetime
import time

from absl import flags
from absl import logging
import tensorflow as tf

import data_utils   # Project-local: CLEVR dataset iterators.
import model_utils  # Project-local: model construction helpers.
import utils        # Project-local: evaluation metrics.

FLAGS = flags.FLAGS

def main(argv):
    del argv
    # Hyperparameters of the model.
    batch_size = FLAGS.batch_size
    num_slots = FLAGS.num_slots
    num_iterations = FLAGS.num_iterations
    base_learning_rate = FLAGS.learning_rate
    num_train_steps = FLAGS.num_train_steps
    warmup_steps = FLAGS.warmup_steps
    decay_rate = FLAGS.decay_rate
    decay_steps = FLAGS.decay_steps
    tf.random.set_seed(FLAGS.seed)
    resolution = (128, 128)

    # Build dataset iterators, optimizers and model.
    data_iterator = data_utils.build_clevr_iterator(batch_size,
                                                    split="train",
                                                    resolution=resolution,
                                                    shuffle=True,
                                                    max_n_objects=10,
                                                    get_properties=True,
                                                    apply_crop=False)
    data_iterator_validation = data_utils.build_clevr_iterator(
        batch_size,
        split="train_eval",
        resolution=resolution,
        shuffle=False,
        max_n_objects=10,
        get_properties=True,
        apply_crop=False)

    optimizer = tf.keras.optimizers.Adam(base_learning_rate, epsilon=1e-08)

    model = model_utils.build_model(resolution,
                                    batch_size,
                                    num_slots,
                                    num_iterations,
                                    model_type="set_prediction")

    # Prepare checkpoint manager.
    global_step = tf.Variable(0,
                              trainable=False,
                              name="global_step",
                              dtype=tf.int64)
    ckpt = tf.train.Checkpoint(network=model,
                               optimizer=optimizer,
                               global_step=global_step)
    ckpt_manager = tf.train.CheckpointManager(checkpoint=ckpt,
                                              directory=FLAGS.model_dir,
                                              max_to_keep=5)
    ckpt.restore(ckpt_manager.latest_checkpoint)
    if ckpt_manager.latest_checkpoint:
        logging.info("Restored from %s", ckpt_manager.latest_checkpoint)
    else:
        logging.info("Initializing from scratch.")

    start = time.time()
    for _ in range(num_train_steps):
        batch = next(data_iterator)

        # Learning rate warm-up.
        if global_step < warmup_steps:
            learning_rate = base_learning_rate * tf.cast(
                global_step, tf.float32) / tf.cast(warmup_steps, tf.float32)
        else:
            learning_rate = base_learning_rate
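        # Combined schedule:
        #   lr(step) = base_lr * min(step / warmup_steps, 1)
        #              * decay_rate ** (step / decay_steps),
        # i.e. a linear warm-up followed by a smooth exponential decay.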
        learning_rate = learning_rate * (decay_rate**(tf.cast(
            global_step, tf.float32) / tf.cast(decay_steps, tf.float32)))
        optimizer.lr = learning_rate.numpy()

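        # train_step is defined elsewhere in the training script; a sketch of
        # one possible implementation follows Example #2 below.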
        loss_value = train_step(batch, model, optimizer)

        # Update the global step. We update it before logging the loss and saving
        # the model so that the last checkpoint is saved at the last iteration.
        global_step.assign_add(1)

        # Log the training loss every 100 steps. Every 1000 steps, we also
        # log the validation average precision and save a checkpoint.
        if not global_step % 100:
            logging.info("Step: %s, Loss: %.6f, Time: %s", global_step.numpy(),
                         loss_value,
                         datetime.timedelta(seconds=time.time() - start))
        if not global_step % 1000:
            # To compute the AP score, we fetch a single batch from the validation set.
            batch = next(data_iterator_validation)
            preds = model(batch["image"], training=False)
            ap = [
                utils.average_precision_clevr(preds, batch["target"], d)
                for d in [-1., 1., 0.5, 0.25, 0.125]
            ]
            logging.info(
                "AP@inf: %.2f, AP@1: %.2f, [email protected]: %.2f, [email protected]: %.2f, [email protected]: %.2f",
                ap[0], ap[1], ap[2], ap[3], ap[4])

            # Save the checkpoint of the model.
            saved_ckpt = ckpt_manager.save()
            logging.info("Saved checkpoint: %s", saved_ckpt)
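
Example #2 relies on a train_step function that is not included in these snippets. Below is a minimal sketch of what it could look like for the set-prediction model, assuming the project's utils module exposes a Hungarian-matching Huber loss (here called utils.hungarian_huber_loss, an assumed name); any permutation-invariant set loss would fill the same role.

def train_step(batch, model, optimizer):
    """One optimization step (sketch): forward pass, set loss, gradient update."""
    with tf.GradientTape() as tape:
        preds = model(batch["image"], training=True)
        # Predicted slots have no canonical order, so the loss must first match
        # predictions to targets (e.g. with the Hungarian algorithm).
        loss_value = utils.hungarian_huber_loss(preds, batch["target"])
    gradients = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    return loss_value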