def test_set_prediction_model(self):
  """Test set prediction model."""
  learning_rate = 0.001
  resolution = (128, 128)
  batch_size = 2
  num_slots = 3
  num_iterations = 2

  optimizer = tf.keras.optimizers.Adam(learning_rate, epsilon=1e-08)
  model = model_utils.build_model(resolution, batch_size, num_slots,
                                  num_iterations, model_type="set_prediction")

  input_shape = (batch_size, resolution[0], resolution[1], 3)
  random_input = tf.random.uniform(input_shape)
  output_shape = (batch_size, num_slots, 19)
  random_output = tf.random.uniform(output_shape)

  with tf.GradientTape() as tape:
    preds = model(random_input, training=True)
    loss_value = utils.hungarian_huber_loss(preds, random_output)

  # Get and apply gradients.
  gradients = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(gradients, model.trainable_weights))

  assert True  # If we make it to this line, we're all good!

def test_object_discovery_model(self):
  """Test object discovery model."""
  learning_rate = 0.001
  resolution = (128, 128)
  batch_size = 2
  num_slots = 3
  num_iterations = 2

  optimizer = tf.keras.optimizers.Adam(learning_rate, epsilon=1e-08)
  model = model_utils.build_model(resolution, batch_size, num_slots,
                                  num_iterations,
                                  model_type="object_discovery")

  input_shape = (batch_size, resolution[0], resolution[1], 3)
  random_input = tf.random.uniform(input_shape)

  with tf.GradientTape() as tape:
    preds = model(random_input, training=True)
    recon_combined, _, _, _ = preds
    loss_value = utils.l2_loss(random_input, recon_combined)

  # Get and apply gradients.
  gradients = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(gradients, model.trainable_weights))

  assert True  # If we make it to this line, we're all good!

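# The two tests above take `self`, so they are methods of a test class. A
# minimal harness sketch, assuming they live in a tf.test.TestCase subclass
# (the class name `ModelTest` is an assumption, not from the source):


class ModelTest(tf.test.TestCase):
  # test_set_prediction_model and test_object_discovery_model would be
  # indented under this class in the actual test file.
  pass


if __name__ == "__main__":
  tf.test.main()
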
def load_model():
  """Load the latest checkpoint."""
  # Build the model.
  model = model_utils.build_model(
      resolution=(128, 128), batch_size=FLAGS.batch_size,
      num_slots=FLAGS.num_slots, num_iterations=FLAGS.num_iterations,
      model_type="set_prediction")

  # Load the weights.
  ckpt = tf.train.Checkpoint(network=model)
  ckpt_manager = tf.train.CheckpointManager(
      ckpt, directory=FLAGS.checkpoint_dir, max_to_keep=5)

  if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    logging.info("Restored from %s", ckpt_manager.latest_checkpoint)
  else:
    raise ValueError("Failed to load checkpoint.")

  return model

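# Usage sketch for the set-prediction loader above: the model maps a batch of
# images to per-slot property vectors of shape (batch_size, num_slots, 19),
# as exercised in test_set_prediction_model. The helper below and its random
# input are hypothetical stand-ins, not part of the library.

def example_set_prediction():
  model = load_model()
  images = tf.random.uniform((FLAGS.batch_size, 128, 128, 3))
  preds = model(images, training=False)  # Shape: (batch_size, num_slots, 19).
  logging.info("Predictions: %s", preds.shape)
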
def load_model(checkpoint_dir, num_slots=11, num_iters=3, batch_size=16):
  """Load the latest object discovery checkpoint from `checkpoint_dir`."""
  resolution = (128, 128)
  model = model_utils.build_model(resolution, batch_size, num_slots,
                                  num_iters, model_type="object_discovery")

  ckpt = tf.train.Checkpoint(network=model)
  ckpt_manager = tf.train.CheckpointManager(
      ckpt, directory=checkpoint_dir, max_to_keep=5)

  if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    logging.info("Restored from %s", ckpt_manager.latest_checkpoint)

  return model

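# Usage sketch for the object-discovery loader above: the model returns a
# tuple of (combined reconstruction, per-slot reconstructions, per-slot
# masks, slot representations), matching the unpacking in
# test_object_discovery_model. The helper, path, and random input below are
# hypothetical stand-ins, not part of the library.

def example_object_discovery():
  model = load_model("/path/to/checkpoints", num_slots=11, num_iters=3,
                     batch_size=16)
  images = tf.random.uniform((16, 128, 128, 3))  # Stand-in for real data.
  recon_combined, recons, masks, slots = model(images, training=False)
  logging.info("Per-slot recons: %s, masks: %s", recons.shape, masks.shape)
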
def main(argv):
  del argv
  # Hyperparameters of the model.
  batch_size = FLAGS.batch_size
  num_slots = FLAGS.num_slots
  num_iterations = FLAGS.num_iterations
  base_learning_rate = FLAGS.learning_rate
  num_train_steps = FLAGS.num_train_steps
  warmup_steps = FLAGS.warmup_steps
  decay_rate = FLAGS.decay_rate
  decay_steps = FLAGS.decay_steps
  tf.random.set_seed(FLAGS.seed)
  resolution = (128, 128)

  # Build dataset iterators, optimizers and model.
  data_iterator = data_utils.build_clevr_iterator(
      batch_size, split="train", resolution=resolution, shuffle=True,
      max_n_objects=6, get_properties=False, apply_crop=True)

  optimizer = tf.keras.optimizers.Adam(base_learning_rate, epsilon=1e-08)

  model = model_utils.build_model(resolution, batch_size, num_slots,
                                  num_iterations,
                                  model_type="object_discovery")

  # Prepare checkpoint manager.
  global_step = tf.Variable(
      0, trainable=False, name="global_step", dtype=tf.int64)
  ckpt = tf.train.Checkpoint(
      network=model, optimizer=optimizer, global_step=global_step)
  ckpt_manager = tf.train.CheckpointManager(
      checkpoint=ckpt, directory=FLAGS.model_dir, max_to_keep=5)
  ckpt.restore(ckpt_manager.latest_checkpoint)
  if ckpt_manager.latest_checkpoint:
    logging.info("Restored from %s", ckpt_manager.latest_checkpoint)
  else:
    logging.info("Initializing from scratch.")

  start = time.time()
  for _ in range(num_train_steps):
    batch = next(data_iterator)

    # Learning rate warm-up.
    if global_step < warmup_steps:
      learning_rate = base_learning_rate * tf.cast(
          global_step, tf.float32) / tf.cast(warmup_steps, tf.float32)
    else:
      learning_rate = base_learning_rate
    learning_rate = learning_rate * (decay_rate ** (
        tf.cast(global_step, tf.float32) / tf.cast(decay_steps, tf.float32)))
    optimizer.lr = learning_rate.numpy()

    loss_value = train_step(batch, model, optimizer)

    # Update the global step. We update it before logging the loss and saving
    # the model so that the last checkpoint is saved at the last iteration.
    global_step.assign_add(1)

    # Log the training loss.
    if not global_step % 100:
      logging.info("Step: %s, Loss: %.6f, Time: %s",
                   global_step.numpy(), loss_value,
                   datetime.timedelta(seconds=time.time() - start))

    # We save the checkpoints every 1000 iterations.
    if not global_step % 1000:
      # Save the checkpoint of the model.
      saved_ckpt = ckpt_manager.save()
      logging.info("Saved checkpoint: %s", saved_ckpt)

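# `main` above calls `train_step`, which is not shown in this section. A
# minimal sketch consistent with the object-discovery loss used in the tests;
# the assumption that the batch dict carries images under an "image" key is
# mine, not confirmed by this section.


@tf.function
def train_step(batch, model, optimizer):
  """Perform a single training step."""
  # Get the prediction of the model and compute the reconstruction loss.
  with tf.GradientTape() as tape:
    preds = model(batch["image"], training=True)
    recon_combined, _, _, _ = preds
    loss_value = utils.l2_loss(batch["image"], recon_combined)

  # Get and apply gradients.
  gradients = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(gradients, model.trainable_weights))
  return loss_value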