Example #1
class Autoencoder(tf.keras.Model):  # Enclosing class, shown here for context.

    def __init__(self, cfg, data_shapes):
        """Constructs the autoencoder.

        Args:
          cfg: ConfigDict with model hyperparameters.
          data_shapes: Dict of shapes of model input tensors, as returned by
            datasets.get_sequence_dataset.
        """
        input_sequence = tf.keras.Input(shape=data_shapes['image'][1:],
                                        name='image')

        # Detect keypoints in every frame of the input sequence.
        image_shape = data_shapes['image'][1:]
        keypoints, _ = vision.build_images_to_keypoints_net(
            cfg, image_shape)(input_sequence)

        # Reconstruct the sequence from the keypoints, conditioned on the first
        # frame and its keypoints.
        reconstructed_sequence = vision.build_keypoints_to_images_net(
            cfg, image_shape)([
                keypoints, input_sequence[:, 0, Ellipsis],
                keypoints[:, 0, Ellipsis]
            ])

        super(Autoencoder, self).__init__(inputs=input_sequence,
                                          outputs=reconstructed_sequence,
                                          name='autoencoder')

        # L2 reconstruction loss between the input and reconstructed sequences.
        self.add_loss(tf.nn.l2_loss(input_sequence - reconstructed_sequence))
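Because the constructor registers the L2 reconstruction loss via add_loss(), training this autoencoder only needs compile() with an optimizer and fit() on data that supplies the 'image' input. The sketch below illustrates that pattern; hyperparameters.get_config() and the exact return signature of datasets.get_sequence_dataset() are assumptions standing in for the project's own config and data utilities.

# Minimal training sketch; helper signatures and cfg fields are assumed.
import tensorflow as tf

cfg = hyperparameters.get_config()                      # assumed ConfigDict factory
dataset, data_shapes = datasets.get_sequence_dataset(   # assumed to return the dataset
    data_dir=cfg.data_dir, batch_size=cfg.batch_size)   # together with its shape dict

autoencoder = Autoencoder(cfg, data_shapes)

# No loss is passed to compile(): the L2 reconstruction loss is already
# attached inside the constructor via add_loss().
autoencoder.compile(optimizer=tf.keras.optimizers.Adam(1e-3))
autoencoder.fit(dataset, epochs=1, steps_per_epoch=100)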
Example #2
  def testImagesToKeypointsNetShapes(self):
    # The vision net maps an image sequence to per-frame keypoints and heatmaps.
    model = vision.build_images_to_keypoints_net(
        self.cfg, self.data_shapes['image'][1:])
    images = tf.zeros((self.cfg.batch_size,) + self.data_shapes['image'][1:])
    keypoints, heatmaps = model(images)
    self.assertEqual(
        keypoints.shape.as_list(),
        [self.cfg.batch_size, self.time_steps, self.cfg.num_keypoints, 3])
    self.assertEqual(
        heatmaps.shape.as_list(),
        [self.cfg.batch_size, self.time_steps, self.cfg.heatmap_width,
         self.cfg.heatmap_width, 3])
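For reference, a setUp() along the following lines supplies the attributes this test reads (self.cfg, self.data_shapes, self.time_steps). The concrete values and the hyperparameters.get_config() helper are illustrative assumptions; the hard-coded trailing 3 in the heatmap assertion suggests the real test config uses three keypoints (one heatmap channel per keypoint), which the sketch mirrors.

# Hypothetical test fixture; field names and values are assumptions.
import tensorflow as tf


class VisionTest(tf.test.TestCase):

  def setUp(self):
    super(VisionTest, self).setUp()
    self.time_steps = 8
    self.cfg = hyperparameters.get_config()  # assumed ConfigDict factory
    self.cfg.batch_size = 4
    self.cfg.num_keypoints = 3               # matches the trailing 3 asserted for heatmaps
    # Shapes include the batch dimension, so data_shapes['image'][1:] drops it:
    self.data_shapes = {'image': (None, self.time_steps, 64, 64, 3)}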
Example #3
import tensorflow as tf  # TF1-style API: tf.to_float below was removed in TF2.
# vision, dynamics, and losses are project-local modules (imports omitted in the original example).


def build_model(cfg, data_shapes):
    """Builds the complete model with image encoder plus dynamics model.

    This architecture is meant for testing/illustration only.

    Model architecture:

      image_sequence --> keypoints --> reconstructed_image_sequence
                            |
                            V
                      dynamics_model --> predicted_keypoints

    The model takes a [batch_size, timesteps, H, W, C] image sequence as input.
    It "observes" all frames, detects keypoints, and reconstructs the images.
    The dynamics model learns to predict future keypoints based on the detected
    keypoints.

    Args:
      cfg: ConfigDict with model hyperparameters.
      data_shapes: Dict of shapes of model input tensors, as returned by
        datasets.get_sequence_dataset.

    Returns:
      tf.keras.Model object.
    """
    # Keras expects input shapes without the batch dimension.
    input_shape_no_batch = data_shapes['image'][1:]
    input_images = tf.keras.Input(shape=input_shape_no_batch, name='image')

    # Vision model:
    observed_keypoints, _ = vision.build_images_to_keypoints_net(
        cfg, input_shape_no_batch)(input_images)
    keypoints_to_images_net = vision.build_keypoints_to_images_net(
        cfg, input_shape_no_batch)
    reconstructed_images = keypoints_to_images_net([
        observed_keypoints, input_images[:, 0, Ellipsis],
        observed_keypoints[:, 0, Ellipsis]
    ])

    # Dynamics model:
    observed_keypoints_stop = tf.keras.layers.Lambda(
        tf.stop_gradient)(observed_keypoints)
    dynamics_model = dynamics.build_vrnn(cfg)
    predicted_keypoints, kl_divergence = dynamics_model(
        observed_keypoints_stop)

    model = tf.keras.Model(inputs=[input_images],
                           outputs=[
                               reconstructed_images, observed_keypoints,
                               predicted_keypoints
                           ],
                           name='autoencoder')

    # Losses:
    image_loss = tf.nn.l2_loss(input_images - reconstructed_images)
    # Normalize by batch size and sequence length:
    image_loss /= tf.to_float(
        tf.shape(input_images)[0] * tf.shape(input_images)[1])
    model.add_loss(image_loss)

    separation_loss = losses.temporal_separation_loss(
        cfg, observed_keypoints[:, :cfg.observed_steps, Ellipsis])
    model.add_loss(cfg.separation_loss_scale * separation_loss)

    vrnn_coord_pred_loss = tf.nn.l2_loss(observed_keypoints_stop -
                                         predicted_keypoints)

    # Normalize by batch size and sequence length:
    vrnn_coord_pred_loss /= tf.to_float(
        tf.shape(input_images)[0] * tf.shape(input_images)[1])
    model.add_loss(vrnn_coord_pred_loss)

    kl_loss = tf.reduce_mean(kl_divergence)  # Mean over batch and timesteps.
    model.add_loss(cfg.kl_loss_scale * kl_loss)

    return model
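As in the autoencoder example, all four loss terms (image reconstruction, keypoint separation, VRNN coordinate prediction, and KL) are registered through add_loss(), so compile() again takes only an optimizer. A minimal sketch of wiring build_model() into a training loop follows, under the same assumptions about the project's config and dataset helpers as above.

# Minimal training sketch; helper signatures and cfg fields are assumed.
import tensorflow as tf

cfg = hyperparameters.get_config()                      # assumed ConfigDict factory
dataset, data_shapes = datasets.get_sequence_dataset(   # assumed to return the dataset
    data_dir=cfg.data_dir, batch_size=cfg.batch_size)   # together with its shape dict

model = build_model(cfg, data_shapes)

# All losses were attached via add_loss(), so no per-output loss is configured.
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3))
model.fit(dataset, epochs=1, steps_per_epoch=100)

# The model outputs reconstructed images, observed keypoints, and predicted
# (future) keypoints for a batch of image sequences.
reconstructed, observed, predicted = model.predict(dataset, steps=1)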