    def test_set_prediction_model(self):
        """Test set prediction model."""

        learning_rate = 0.001
        resolution = (128, 128)
        batch_size = 2
        num_slots = 3
        num_iterations = 2

        optimizer = tf.keras.optimizers.Adam(learning_rate, epsilon=1e-08)
        model = model_utils.build_model(resolution,
                                        batch_size,
                                        num_slots,
                                        num_iterations,
                                        model_type="set_prediction")

        input_shape = (batch_size, resolution[0], resolution[1], 3)
        random_input = tf.random.uniform(input_shape)
        output_shape = (batch_size, num_slots, 19)
        random_output = tf.random.uniform(output_shape)

        with tf.GradientTape() as tape:
            preds = model(random_input, training=True)
            loss_value = utils.hungarian_huber_loss(preds, random_output)

        # Get and apply gradients.
        gradients = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))

        assert True  # If we make it to this line, we're all good!

    def test_object_discovery_model(self):
        """Test object discovery model."""

        learning_rate = 0.001
        resolution = (128, 128)
        batch_size = 2
        num_slots = 3
        num_iterations = 2

        optimizer = tf.keras.optimizers.Adam(learning_rate, epsilon=1e-08)
        model = model_utils.build_model(resolution,
                                        batch_size,
                                        num_slots,
                                        num_iterations,
                                        model_type="object_discovery")

        input_shape = (batch_size, resolution[0], resolution[1], 3)
        random_input = tf.random.uniform(input_shape)

        with tf.GradientTape() as tape:
            preds = model(random_input, training=True)
            recon_combined, _, _, _ = preds
            loss_value = utils.l2_loss(random_input, recon_combined)

        # Get and apply gradients.
        gradients = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))

        assert True  # If we make it to this line, we're all good!
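Both methods above take self, so they are meant to live inside a test class. A minimal harness sketch for running them standalone; the slot_attention import paths are an assumption (the listing only uses the bare identifiers), and the extra smoke test is hypothetical:

import tensorflow as tf

# Assumed import paths; adjust to wherever model_utils and utils actually live.
from slot_attention import model_utils
from slot_attention import utils


class ModelTest(tf.test.TestCase):
    """Smoke tests for model_utils.build_model.

    The two test methods above would be pasted here, indented one level.
    """

    def test_build_model_smoke(self):
        model = model_utils.build_model(
            resolution=(128, 128), batch_size=2, num_slots=3,
            num_iterations=2, model_type="object_discovery")
        # The tests above call the model directly and read trainable_weights,
        # so at minimum the returned object should support both.
        self.assertTrue(callable(model))
        self.assertTrue(hasattr(model, "trainable_weights"))


if __name__ == "__main__":
    tf.test.main()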
Example #3
def load_model():
  """Load the latest checkpoint."""
  # Build the model.
  model = model_utils.build_model(
      resolution=(128, 128), batch_size=FLAGS.batch_size,
      num_slots=FLAGS.num_slots, num_iterations=FLAGS.num_iterations,
      model_type="set_prediction")
  # Load the weights.
  ckpt = tf.train.Checkpoint(network=model)
  ckpt_manager = tf.train.CheckpointManager(
      ckpt, directory=FLAGS.checkpoint_dir, max_to_keep=5)
  if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    logging.info("Restored from %s", ckpt_manager.latest_checkpoint)
  else:
    raise ValueError("Failed to load checkpoint.")
  return model
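After restoring, the set-prediction model maps a batch of images to one property vector per slot (19-dimensional in the unit test above). A hypothetical usage sketch; the random images stand in for a real CLEVR batch, and FLAGS must already be parsed since load_model reads it:

import tensorflow as tf

model = load_model()
# Stand-in input; a real run would pull a batch from the CLEVR iterator.
images = tf.random.uniform((FLAGS.batch_size, 128, 128, 3))
preds = model(images, training=False)
# Expected shape: [FLAGS.batch_size, FLAGS.num_slots, property_dim],
# where property_dim is 19 in the unit test above.
print(preds.shape)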
Example #4
def load_model(checkpoint_dir, num_slots=11, num_iters=3, batch_size=16):
    """Build the object-discovery model and restore the latest checkpoint."""
    resolution = (128, 128)
    model = model_utils.build_model(resolution,
                                    batch_size,
                                    num_slots,
                                    num_iters,
                                    model_type="object_discovery")

    ckpt = tf.train.Checkpoint(network=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              directory=checkpoint_dir,
                                              max_to_keep=5)

    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        logging.info("Restored from %s", ckpt_manager.latest_checkpoint)

    return model
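The object-discovery model returns a 4-tuple whose first element is the combined reconstruction (see the unit test above); the names given to the other three outputs below are an assumption. A hypothetical usage sketch:

import tensorflow as tf

model = load_model("/path/to/checkpoints")  # hypothetical checkpoint directory
images = tf.random.uniform((16, 128, 128, 3))  # stand-in for a CLEVR batch
recon_combined, recons, masks, slots = model(images, training=False)
# recon_combined: full-image reconstruction with the same spatial size as the
# input; recons/masks/slots are per-slot outputs (names assumed here).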
Example #5
def main(argv):
    del argv  # Unused.
    # Hyperparameters of the model.
    batch_size = FLAGS.batch_size
    num_slots = FLAGS.num_slots
    num_iterations = FLAGS.num_iterations
    base_learning_rate = FLAGS.learning_rate
    num_train_steps = FLAGS.num_train_steps
    warmup_steps = FLAGS.warmup_steps
    decay_rate = FLAGS.decay_rate
    decay_steps = FLAGS.decay_steps
    tf.random.set_seed(FLAGS.seed)
    resolution = (128, 128)

    # Build dataset iterators, optimizers and model.
    data_iterator = data_utils.build_clevr_iterator(batch_size,
                                                    split="train",
                                                    resolution=resolution,
                                                    shuffle=True,
                                                    max_n_objects=6,
                                                    get_properties=False,
                                                    apply_crop=True)

    optimizer = tf.keras.optimizers.Adam(base_learning_rate, epsilon=1e-08)

    model = model_utils.build_model(resolution,
                                    batch_size,
                                    num_slots,
                                    num_iterations,
                                    model_type="object_discovery")

    # Prepare checkpoint manager.
    global_step = tf.Variable(0,
                              trainable=False,
                              name="global_step",
                              dtype=tf.int64)
    ckpt = tf.train.Checkpoint(network=model,
                               optimizer=optimizer,
                               global_step=global_step)
    ckpt_manager = tf.train.CheckpointManager(checkpoint=ckpt,
                                              directory=FLAGS.model_dir,
                                              max_to_keep=5)
    ckpt.restore(ckpt_manager.latest_checkpoint)
    if ckpt_manager.latest_checkpoint:
        logging.info("Restored from %s", ckpt_manager.latest_checkpoint)
    else:
        logging.info("Initializing from scratch.")

    start = time.time()
    for _ in range(num_train_steps):
        batch = next(data_iterator)

        # Learning rate warm-up.
        if global_step < warmup_steps:
            learning_rate = base_learning_rate * tf.cast(
                global_step, tf.float32) / tf.cast(warmup_steps, tf.float32)
        else:
            learning_rate = base_learning_rate
        learning_rate = learning_rate * (decay_rate**(tf.cast(
            global_step, tf.float32) / tf.cast(decay_steps, tf.float32)))
        optimizer.lr = learning_rate.numpy()

        loss_value = train_step(batch, model, optimizer)

        # Update the global step. We update it before logging the loss and saving
        # the model so that the last checkpoint is saved at the last iteration.
        global_step.assign_add(1)

        # Log the training loss.
        if not global_step % 100:
            logging.info("Step: %s, Loss: %.6f, Time: %s", global_step.numpy(),
                         loss_value,
                         datetime.timedelta(seconds=time.time() - start))

        # We save the checkpoints every 1000 iterations.
        if not global_step % 1000:
            # Save the checkpoint of the model.
            saved_ckpt = ckpt_manager.save()
            logging.info("Saved checkpoint: %s", saved_ckpt)