def main(argv):
  del argv
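  # Restore the trained model, evaluate it on the CLEVR validation split, and
  # log average precision (AP) at several distance thresholds.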
  model = load_model()
  dataset = data_utils.build_clevr_iterator(
      batch_size=FLAGS.batch_size, split="validation", resolution=(128, 128))
  ap = run_eval(model, dataset)
  logging.info(
      "AP@inf: %.2f, AP@1: %.2f, [email protected]: %.2f, [email protected]: %.2f, [email protected]: %.2f.",
      ap[0], ap[1], ap[2], ap[3], ap[4])
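load_model and run_eval are defined elsewhere in the source file and are not shown here. As a rough, hypothetical sketch (an assumption, not the repository's actual code), load_model would rebuild the network and restore the latest checkpoint from FLAGS.model_dir, mirroring the checkpoint setup used in Example #2 below:

def load_model():
  """Hypothetical helper: rebuild the model and restore the latest checkpoint."""
  model = model_utils.build_model(
      (128, 128), FLAGS.batch_size, FLAGS.num_slots, FLAGS.num_iterations,
      model_type="object_discovery")
  ckpt = tf.train.Checkpoint(network=model)
  ckpt.restore(tf.train.latest_checkpoint(FLAGS.model_dir)).expect_partial()
  return model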
Example #2
def main(argv):
    del argv
    # Hyperparameters of the model.
    batch_size = FLAGS.batch_size
    num_slots = FLAGS.num_slots
    num_iterations = FLAGS.num_iterations
    base_learning_rate = FLAGS.learning_rate
    num_train_steps = FLAGS.num_train_steps
    warmup_steps = FLAGS.warmup_steps
    decay_rate = FLAGS.decay_rate
    decay_steps = FLAGS.decay_steps
    tf.random.set_seed(FLAGS.seed)
    resolution = (128, 128)

    # Build dataset iterators, optimizers and model.
    data_iterator = data_utils.build_clevr_iterator(batch_size,
                                                    split="train",
                                                    resolution=resolution,
                                                    shuffle=True,
                                                    max_n_objects=6,
                                                    get_properties=False,
                                                    apply_crop=True)

    optimizer = tf.keras.optimizers.Adam(base_learning_rate, epsilon=1e-08)

    model = model_utils.build_model(resolution,
                                    batch_size,
                                    num_slots,
                                    num_iterations,
                                    model_type="object_discovery")

    # Prepare checkpoint manager.
    global_step = tf.Variable(0,
                              trainable=False,
                              name="global_step",
                              dtype=tf.int64)
    ckpt = tf.train.Checkpoint(network=model,
                               optimizer=optimizer,
                               global_step=global_step)
    ckpt_manager = tf.train.CheckpointManager(checkpoint=ckpt,
                                              directory=FLAGS.model_dir,
                                              max_to_keep=5)
    ckpt.restore(ckpt_manager.latest_checkpoint)
    if ckpt_manager.latest_checkpoint:
        logging.info("Restored from %s", ckpt_manager.latest_checkpoint)
    else:
        logging.info("Initializing from scratch.")

    start = time.time()
    for _ in range(num_train_steps):
        batch = next(data_iterator)

        # Learning rate warm-up.
        if global_step < warmup_steps:
            learning_rate = base_learning_rate * tf.cast(
                global_step, tf.float32) / tf.cast(warmup_steps, tf.float32)
        else:
            learning_rate = base_learning_rate
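        # Exponential decay is applied on top of the (possibly warmed-up) rate:
        # lr(step) = base_lr * min(1, step / warmup_steps)
        #                    * decay_rate ** (step / decay_steps).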
        learning_rate = learning_rate * (decay_rate**(tf.cast(
            global_step, tf.float32) / tf.cast(decay_steps, tf.float32)))
        optimizer.lr = learning_rate.numpy()

        loss_value = train_step(batch, model, optimizer)

        # Update the global step. We update it before logging the loss and saving
        # the model so that the last checkpoint is saved at the last iteration.
        global_step.assign_add(1)

        # Log the training loss.
        if not global_step % 100:
            logging.info("Step: %s, Loss: %.6f, Time: %s", global_step.numpy(),
                         loss_value,
                         datetime.timedelta(seconds=time.time() - start))

        # We save the checkpoints every 1000 iterations.
        if not global_step % 1000:
            # Save the checkpoint of the model.
            saved_ckpt = ckpt_manager.save()
            logging.info("Saved checkpoint: %s", saved_ckpt)