Example #1
  def build_objectmotion_test_graph(self):
    """Builds egomotion model reading from placeholders."""
    input_image_stack_om = tf.placeholder(
        tf.float32,
        [1, self.img_height, self.img_width, self.seq_length * 3],
        name='raw_input')

    if self.imagenet_norm:
      im_mean = tf.tile(
          tf.constant(reader.IMAGENET_MEAN), multiples=[self.seq_length])
      im_sd = tf.tile(
          tf.constant(reader.IMAGENET_SD), multiples=[self.seq_length])
      input_image_stack_om = (input_image_stack_om - im_mean) / im_sd

    with tf.variable_scope('objectmotion_prediction'):
      est_objectmotion = nets.objectmotion_net(
          image_stack=input_image_stack_om,
          disp_bottleneck_stack=None,
          joint_encoder=self.joint_encoder,
          seq_length=self.seq_length,
          weight_reg=self.weight_reg)
    self.input_image_stack_om = input_image_stack_om
    self.est_objectmotion = est_objectmotion
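
A minimal usage sketch for the placeholder-based test graph above, assuming model is an instance of the surrounding class and checkpoint_path points to a trained checkpoint (both names are hypothetical):

import numpy as np
import tensorflow as tf

model.build_objectmotion_test_graph()
saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess, checkpoint_path)
  # One batch: seq_length RGB frames stacked along the channel axis.
  frames = np.zeros(
      [1, model.img_height, model.img_width, model.seq_length * 3],
      dtype=np.float32)
  objectmotion = sess.run(
      model.est_objectmotion,
      feed_dict={model.input_image_stack_om: frames})
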
Example #2
  def build_objectmotion_test_graph(self):
    """Builds egomotion model reading from placeholders."""
    input_image_stack_om = tf.placeholder(
        tf.float32,
        [1, self.img_height, self.img_width, self.seq_length * 3],
        name='raw_input')

    if self.imagenet_norm:
      im_mean = tf.tile(
          tf.constant(reader.IMAGENET_MEAN), multiples=[self.seq_length])
      im_sd = tf.tile(
          tf.constant(reader.IMAGENET_SD), multiples=[self.seq_length])
      input_image_stack_om = (input_image_stack_om - im_mean) / im_sd

    with tf.variable_scope('objectmotion_prediction'):
      est_objectmotion = nets.objectmotion_net(
          image_stack=input_image_stack_om,
          disp_bottleneck_stack=None,
          joint_encoder=self.joint_encoder,
          seq_length=self.seq_length,
          weight_reg=self.weight_reg)
    self.input_image_stack_om = input_image_stack_om
    self.est_objectmotion = est_objectmotion
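
The ImageNet normalization in the snippet above works by tiling the 3-channel mean and standard deviation across all seq_length frames so they broadcast over the (1, H, W, seq_length * 3) input. A NumPy sketch of the same arithmetic, assuming the usual ImageNet statistics for reader.IMAGENET_MEAN and reader.IMAGENET_SD:

import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # assumed
IMAGENET_SD = np.array([0.229, 0.224, 0.225], dtype=np.float32)    # assumed
seq_length = 3

im_mean = np.tile(IMAGENET_MEAN, seq_length)  # shape (9,)
im_sd = np.tile(IMAGENET_SD, seq_length)      # shape (9,)
frames = np.random.rand(1, 128, 416, seq_length * 3).astype(np.float32)
normalized = (frames - im_mean) / im_sd  # broadcasts over the last axis
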
Example #3
  def build_inference_for_training(self):
    """Invokes depth and ego-motion networks and computes clouds if needed."""
    (self.image_stack, self.image_stack_norm, self.seg_stack,
     self.intrinsic_mat, self.intrinsic_mat_inv) = self.reader.read_data()
    with tf.variable_scope('depth_prediction'):
      # Organized by ...[i][scale].  Note that the order is flipped in
      # variables in build_loss() below.
      self.disp = {}
      self.depth = {}
      self.depth_upsampled = {}
      self.inf_loss = 0.0
      # Organized by [i].
      disp_bottlenecks = [None] * self.seq_length

      if self.icp_weight > 0:
        self.cloud = {}
      for i in range(self.seq_length):
        image = self.image_stack_norm[:, :, :, 3 * i:3 * (i + 1)]

        multiscale_disps_i, disp_bottlenecks[i] = nets.disp_net(
            self.architecture, image, self.use_skip,
            self.weight_reg, True)
        multiscale_depths_i = [1.0 / d for d in multiscale_disps_i]
        self.disp[i] = multiscale_disps_i
        self.depth[i] = multiscale_depths_i
        if self.depth_upsampling:
          self.depth_upsampled[i] = []
          # Upsample low-resolution depth maps using differentiable bilinear
          # interpolation.
          for s in range(len(multiscale_depths_i)):
            self.depth_upsampled[i].append(tf.image.resize_bilinear(
                multiscale_depths_i[s], [self.img_height, self.img_width],
                align_corners=True))

        if self.icp_weight > 0:
          multiscale_clouds_i = [
              project.get_cloud(d,
                                self.intrinsic_mat_inv[:, s, :, :],
                                name='cloud%d_%d' % (s, i))
              for (s, d) in enumerate(multiscale_depths_i)
          ]
          self.cloud[i] = multiscale_clouds_i
        # Reuse the same depth graph for all images.
        tf.get_variable_scope().reuse_variables()

    if self.handle_motion:
      # Define egomotion network. This network sees the whole scene except
      # for any moving objects, as indicated by the provided segmentation
      # masks. To keep the network from picking up motion cues by tracking
      # those masks, we mask out the temporal union of the object masks.
      tf.logging.info('handle_motion: building object-motion-aware graph.')
      with tf.variable_scope('egomotion_prediction'):
        base_input = self.image_stack_norm  # (B, H, W, 9)
        seg_input = self.seg_stack  # (B, H, W, 9)
        ref_zero = tf.constant(0, dtype=tf.uint8)
        # Motion model is currently defined for three-frame sequences.
        object_mask1 = tf.equal(seg_input[:, :, :, 0], ref_zero)
        object_mask2 = tf.equal(seg_input[:, :, :, 3], ref_zero)
        object_mask3 = tf.equal(seg_input[:, :, :, 6], ref_zero)
        mask_complete = tf.expand_dims(tf.logical_and(  # (B, H, W, 1)
            tf.logical_and(object_mask1, object_mask2), object_mask3), axis=3)
        mask_complete = tf.tile(mask_complete, (1, 1, 1, 9))  # (B, H, W, 9)
        # Now mask out base_input.
        self.mask_complete = tf.to_float(mask_complete)
        self.base_input_masked = base_input * self.mask_complete
        self.egomotion = nets.egomotion_net(
            image_stack=self.base_input_masked,
            disp_bottleneck_stack=None,
            joint_encoder=False,
            seq_length=self.seq_length,
            weight_reg=self.weight_reg)

        # Debug: self.egomotion is a graph tensor at this point; it can only
        # be evaluated in a session after variable initialization, so only
        # the symbolic tensor is logged during graph construction.
        tf.logging.info('egomotion tensor: %s', self.egomotion)

      # Define object motion network for refinement. This network only sees
      # one object at a time over the whole sequence and tries to estimate
      # its motion. Its inputs are the corresponding warped frames.

      # For each scale, contains batch_size elements of shape (N, 2, 6).
      self.object_transforms = {}
      # For each scale, contains batch_size elements of shape (N, H, W, 9).
      self.object_masks = {}
      self.object_masks_warped = {}
      # For each scale, contains batch_size elements of size N.
      self.object_ids = {}

      self.egomotions_seq = {}
      self.warped_seq = {}
      self.inputs_objectmotion_net = {}
      with tf.variable_scope('objectmotion_prediction'):
        # First, warp raw images according to overall egomotion.
        for s in range(NUM_SCALES):
          self.warped_seq[s] = []
          self.egomotions_seq[s] = []
          for source_index in range(self.seq_length):
            egomotion_mat_i_1 = project.get_transform_mat(
                self.egomotion, source_index, 1)
            warped_image_i_1, _ = (
                project.inverse_warp(
                    self.image_stack[
                        :, :, :, source_index*3:(source_index+1)*3],
                    self.depth_upsampled[1][s],
                    egomotion_mat_i_1,
                    self.intrinsic_mat[:, 0, :, :],
                    self.intrinsic_mat_inv[:, 0, :, :]))

            self.warped_seq[s].append(warped_image_i_1)
            self.egomotions_seq[s].append(egomotion_mat_i_1)

          # Second, for every object in the segmentation mask, take its mask
          # and warp it according to the egomotion estimate. Then apply a
          # threshold to binarize the warped result. Use this mask to mask out
          # background and other objects, and pass the filtered image to the
          # object motion network.
          self.object_transforms[s] = []
          self.object_masks[s] = []
          self.object_ids[s] = []
          self.object_masks_warped[s] = []
          self.inputs_objectmotion_net[s] = {}

          for i in range(self.batch_size):
            seg_sequence = self.seg_stack[i]  # (H, W, 9=3*3)
            object_ids = tf.unique(tf.reshape(seg_sequence, [-1]))[0]
            self.object_ids[s].append(object_ids)
            color_stack = []
            mask_stack = []
            mask_stack_warped = []
            for j in range(self.seq_length):
              current_image = self.warped_seq[s][j][i]  # (H, W, 3)
              current_seg = seg_sequence[:, :, j * 3:(j+1) * 3]  # (H, W, 3)

              def process_obj_mask_warp(obj_id):
                """Performs warping of the individual object masks."""
                obj_mask = tf.to_float(tf.equal(current_seg, obj_id))
                # Warp obj_mask according to overall egomotion.
                obj_mask_warped, _ = (
                    project.inverse_warp(
                        tf.expand_dims(obj_mask, axis=0),
                        # Middle frame, highest scale, batch element i:
                        tf.expand_dims(self.depth_upsampled[1][s][i], axis=0),
                        # Matrix for warping j into middle frame, batch elem. i:
                        tf.expand_dims(self.egomotions_seq[s][j][i], axis=0),
                        tf.expand_dims(self.intrinsic_mat[i, 0, :, :], axis=0),
                        tf.expand_dims(self.intrinsic_mat_inv[i, 0, :, :],
                                       axis=0)))
                obj_mask_warped = tf.squeeze(obj_mask_warped)
                obj_mask_binarized = tf.greater(  # Threshold to binarize mask.
                    obj_mask_warped, tf.constant(0.5))
                return tf.to_float(obj_mask_binarized)

              def process_obj_mask(obj_id):
                """Returns the individual object masks separately."""
                return tf.to_float(tf.equal(current_seg, obj_id))
              object_masks = tf.map_fn(  # (N, H, W, 3)
                  process_obj_mask, object_ids, dtype=tf.float32)

              if self.size_constraint_weight > 0:
                # The object segmentation masks are all in object_masks.
                # We need to measure the height of each of them and get the
                # approximate distance.

                # self.depth_upsampled of shape (seq_length, scale, B, H, W).
                depth_pred = self.depth_upsampled[j][s][i]  # (H, W)
                def get_losses(obj_mask):
                  """Get motion constraint loss."""
                  # Find height of segment.
                  coords = tf.where(tf.greater(  # Shape (num_true, 2=yx)
                      obj_mask[:, :, 0], tf.constant(0.5, dtype=tf.float32)))
                  y_max = tf.reduce_max(coords[:, 0])
                  y_min = tf.reduce_min(coords[:, 0])
                  seg_height = y_max - y_min
                  f_y = self.intrinsic_mat[i, 0, 1, 1]
                  approx_depth = ((f_y * self.global_scale_var) /
                                  tf.to_float(seg_height))
                  reference_pred = tf.boolean_mask(
                      depth_pred, tf.greater(
                          tf.reshape(obj_mask[:, :, 0],
                                     (self.img_height, self.img_width, 1)),
                          tf.constant(0.5, dtype=tf.float32)))

                  # Establish loss on approx_depth, a scalar, and
                  # reference_pred, our dense prediction. Normalize both to
                  # prevent degenerative depth shrinking.
                  global_mean_depth_pred = tf.reduce_mean(depth_pred)
                  reference_pred /= global_mean_depth_pred
                  approx_depth /= global_mean_depth_pred
                  spatial_err = tf.abs(reference_pred - approx_depth)
                  # Debug: spatial_err is a 1-D tensor of per-pixel errors.
                  tf.logging.info('spatial_err tensor: %s', spatial_err)
                  mean_spatial_err = tf.reduce_mean(spatial_err)
                  return mean_spatial_err

                losses = tf.map_fn(
                    get_losses, object_masks, dtype=tf.float32)
                self.inf_loss += tf.reduce_mean(losses)
                # Debug: losses and self.inf_loss are graph tensors here.
                tf.logging.info('losses: %s, inf_loss: %s',
                                losses, self.inf_loss)
              object_masks_warped = tf.map_fn(  # (N, H, W, 3)
                  process_obj_mask_warp, object_ids, dtype=tf.float32)
              filtered_images = tf.map_fn(
                  lambda mask: current_image * mask, object_masks_warped,
                  dtype=tf.float32)  # (N, H, W, 3)
              color_stack.append(filtered_images)
              mask_stack.append(object_masks)
              mask_stack_warped.append(object_masks_warped)

            # For this batch-element, if there are N moving objects,
            # color_stack, mask_stack and mask_stack_warped each contain
            # seq_length elements of shape (N, H, W, 3).
            # We can now concatenate them on the last axis, creating a tensor of
            # (N, H, W, 3*3 = 9), and, assuming N does not get too large so that
            # we have enough memory, pass them in a single batch to the object
            # motion network.
            mask_stack = tf.concat(mask_stack, axis=3)  # (N, H, W, 9)
            mask_stack_warped = tf.concat(mask_stack_warped, axis=3)
            color_stack = tf.concat(color_stack, axis=3)  # (N, H, W, 9)
            all_transforms = nets.objectmotion_net(
                # We cut the gradient flow here as the object motion gradient
                # should have no say in how the egomotion network behaves.
                # One could try just stopping the gradient for egomotion, but
                # not for the depth prediction network.
                image_stack=tf.stop_gradient(color_stack),
                disp_bottleneck_stack=None,
                joint_encoder=False,  # Joint encoder not supported.
                seq_length=self.seq_length,
                weight_reg=self.weight_reg)
            # all_transforms of shape (N, 2, 6).
            self.object_transforms[s].append(all_transforms)
            self.object_masks[s].append(mask_stack)
            self.object_masks_warped[s].append(mask_stack_warped)
            self.inputs_objectmotion_net[s][i] = color_stack
            tf.get_variable_scope().reuse_variables()
      tf.logging.info('handle_motion: object-motion graph built.')
    else:
      # Don't handle motion, classic model formulation.
      with tf.name_scope('egomotion_prediction'):
        if self.joint_encoder:
          # Re-arrange disp_bottleneck_stack to be of shape
          # [B, h_hid, w_hid, c_hid * seq_length]. Currently, it is a list with
          # seq_length elements, each of dimension [B, h_hid, w_hid, c_hid].
          disp_bottleneck_stack = tf.concat(disp_bottlenecks, axis=3)
        else:
          disp_bottleneck_stack = None
        self.egomotion = nets.egomotion_net(
            image_stack=self.image_stack_norm,
            disp_bottleneck_stack=disp_bottleneck_stack,
            joint_encoder=self.joint_encoder,
            seq_length=self.seq_length,
            weight_reg=self.weight_reg)
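
The size constraint inside get_losses() above turns an object's pixel height into an approximate depth through the pinhole relation depth ≈ f_y * object_height / pixel_height, with the metric object height absorbed into the learned global_scale_var. A NumPy sketch of that geometry with made-up numbers:

import numpy as np

f_y = 720.0              # focal length in pixels (hypothetical)
global_scale_var = 1.6   # stands in for the learned height prior (hypothetical)

# Toy binary mask: the object occupies rows 100..149.
obj_mask = np.zeros((256, 832), dtype=np.float32)
obj_mask[100:150, 300:360] = 1.0

coords = np.argwhere(obj_mask > 0.5)                  # (num_true, 2) as (y, x)
seg_height = coords[:, 0].max() - coords[:, 0].min()  # 49 pixels
approx_depth = f_y * global_scale_var / float(seg_height)  # roughly 23.5
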
Example #4
  def build_inference_for_training(self):
    """Invokes depth and ego-motion networks and computes clouds if needed."""
    (self.image_stack, self.image_stack_norm, self.seg_stack,
     self.intrinsic_mat, self.intrinsic_mat_inv) = self.reader.read_data()
    with tf.variable_scope('depth_prediction'):
      # Organized by ...[i][scale].  Note that the order is flipped in
      # variables in build_loss() below.
      self.disp = {}
      self.depth = {}
      self.depth_upsampled = {}
      self.inf_loss = 0.0
      # Organized by [i].
      disp_bottlenecks = [None] * self.seq_length

      if self.icp_weight > 0:
        self.cloud = {}
      for i in range(self.seq_length):
        image = self.image_stack_norm[:, :, :, 3 * i:3 * (i + 1)]

        multiscale_disps_i, disp_bottlenecks[i] = nets.disp_net(
            self.architecture, image, self.use_skip,
            self.weight_reg, True)
        multiscale_depths_i = [1.0 / d for d in multiscale_disps_i]
        self.disp[i] = multiscale_disps_i
        self.depth[i] = multiscale_depths_i
        if self.depth_upsampling:
          self.depth_upsampled[i] = []
          # Upsample low-resolution depth maps using differentiable bilinear
          # interpolation.
          for s in range(len(multiscale_depths_i)):
            self.depth_upsampled[i].append(tf.image.resize_bilinear(
                multiscale_depths_i[s], [self.img_height, self.img_width],
                align_corners=True))

        if self.icp_weight > 0:
          multiscale_clouds_i = [
              project.get_cloud(d,
                                self.intrinsic_mat_inv[:, s, :, :],
                                name='cloud%d_%d' % (s, i))
              for (s, d) in enumerate(multiscale_depths_i)
          ]
          self.cloud[i] = multiscale_clouds_i
        # Reuse the same depth graph for all images.
        tf.get_variable_scope().reuse_variables()

    if self.handle_motion:
      # Define egomotion network. This network sees the whole scene except
      # for any moving objects, as indicated by the provided segmentation
      # masks. To keep the network from picking up motion cues by tracking
      # those masks, we mask out the temporal union of the object masks.
      with tf.variable_scope('egomotion_prediction'):
        base_input = self.image_stack_norm  # (B, H, W, 9)
        seg_input = self.seg_stack  # (B, H, W, 9)
        ref_zero = tf.constant(0, dtype=tf.uint8)
        # Motion model is currently defined for three-frame sequences.
        object_mask1 = tf.equal(seg_input[:, :, :, 0], ref_zero)
        object_mask2 = tf.equal(seg_input[:, :, :, 3], ref_zero)
        object_mask3 = tf.equal(seg_input[:, :, :, 6], ref_zero)
        mask_complete = tf.expand_dims(tf.logical_and(  # (B, H, W, 1)
            tf.logical_and(object_mask1, object_mask2), object_mask3), axis=3)
        mask_complete = tf.tile(mask_complete, (1, 1, 1, 9))  # (B, H, W, 9)
        # Now mask out base_input.
        self.mask_complete = tf.to_float(mask_complete)
        self.base_input_masked = base_input * self.mask_complete
        self.egomotion = nets.egomotion_net(
            image_stack=self.base_input_masked,
            disp_bottleneck_stack=None,
            joint_encoder=False,
            seq_length=self.seq_length,
            weight_reg=self.weight_reg)

      # Define object motion network for refinement. This network only sees
      # one object at a time over the whole sequence and tries to estimate
      # its motion. Its inputs are the corresponding warped frames.

      # For each scale, contains batch_size elements of shape (N, 2, 6).
      self.object_transforms = {}
      # For each scale, contains batch_size elements of shape (N, H, W, 9).
      self.object_masks = {}
      self.object_masks_warped = {}
      # For each scale, contains batch_size elements of size N.
      self.object_ids = {}

      self.egomotions_seq = {}
      self.warped_seq = {}
      self.inputs_objectmotion_net = {}
      with tf.variable_scope('objectmotion_prediction'):
        # First, warp raw images according to overall egomotion.
        for s in range(NUM_SCALES):
          self.warped_seq[s] = []
          self.egomotions_seq[s] = []
          for source_index in range(self.seq_length):
            egomotion_mat_i_1 = project.get_transform_mat(
                self.egomotion, source_index, 1)
            warped_image_i_1, _ = (
                project.inverse_warp(
                    self.image_stack[
                        :, :, :, source_index*3:(source_index+1)*3],
                    self.depth_upsampled[1][s],
                    egomotion_mat_i_1,
                    self.intrinsic_mat[:, 0, :, :],
                    self.intrinsic_mat_inv[:, 0, :, :]))

            self.warped_seq[s].append(warped_image_i_1)
            self.egomotions_seq[s].append(egomotion_mat_i_1)

          # Second, for every object in the segmentation mask, take its mask
          # and warp it according to the egomotion estimate. Then apply a
          # threshold to binarize the warped result. Use this mask to mask out
          # background and other objects, and pass the filtered image to the
          # object motion network.
          self.object_transforms[s] = []
          self.object_masks[s] = []
          self.object_ids[s] = []
          self.object_masks_warped[s] = []
          self.inputs_objectmotion_net[s] = {}

          for i in range(self.batch_size):
            seg_sequence = self.seg_stack[i]  # (H, W, 9=3*3)
            object_ids = tf.unique(tf.reshape(seg_sequence, [-1]))[0]
            self.object_ids[s].append(object_ids)
            color_stack = []
            mask_stack = []
            mask_stack_warped = []
            for j in range(self.seq_length):
              current_image = self.warped_seq[s][j][i]  # (H, W, 3)
              current_seg = seg_sequence[:, :, j * 3:(j+1) * 3]  # (H, W, 3)

              def process_obj_mask_warp(obj_id):
                """Performs warping of the individual object masks."""
                obj_mask = tf.to_float(tf.equal(current_seg, obj_id))
                # Warp obj_mask according to overall egomotion.
                obj_mask_warped, _ = (
                    project.inverse_warp(
                        tf.expand_dims(obj_mask, axis=0),
                        # Middle frame, highest scale, batch element i:
                        tf.expand_dims(self.depth_upsampled[1][s][i], axis=0),
                        # Matrix for warping j into middle frame, batch elem. i:
                        tf.expand_dims(self.egomotions_seq[s][j][i], axis=0),
                        tf.expand_dims(self.intrinsic_mat[i, 0, :, :], axis=0),
                        tf.expand_dims(self.intrinsic_mat_inv[i, 0, :, :],
                                       axis=0)))
                obj_mask_warped = tf.squeeze(obj_mask_warped)
                obj_mask_binarized = tf.greater(  # Threshold to binarize mask.
                    obj_mask_warped, tf.constant(0.5))
                return tf.to_float(obj_mask_binarized)

              def process_obj_mask(obj_id):
                """Returns the individual object masks separately."""
                return tf.to_float(tf.equal(current_seg, obj_id))
              object_masks = tf.map_fn(  # (N, H, W, 3)
                  process_obj_mask, object_ids, dtype=tf.float32)

              if self.size_constraint_weight > 0:
                # The object segmentation masks are all in object_masks.
                # We need to measure the height of each of them and get the
                # approximate distance.

                # self.depth_upsampled of shape (seq_length, scale, B, H, W).
                depth_pred = self.depth_upsampled[j][s][i]  # (H, W)
                def get_losses(obj_mask):
                  """Get motion constraint loss."""
                  # Find height of segment.
                  coords = tf.where(tf.greater(  # Shape (num_true, 2=yx)
                      obj_mask[:, :, 0], tf.constant(0.5, dtype=tf.float32)))
                  y_max = tf.reduce_max(coords[:, 0])
                  y_min = tf.reduce_min(coords[:, 0])
                  seg_height = y_max - y_min
                  f_y = self.intrinsic_mat[i, 0, 1, 1]
                  approx_depth = ((f_y * self.global_scale_var) /
                                  tf.to_float(seg_height))
                  reference_pred = tf.boolean_mask(
                      depth_pred, tf.greater(
                          tf.reshape(obj_mask[:, :, 0],
                                     (self.img_height, self.img_width, 1)),
                          tf.constant(0.5, dtype=tf.float32)))

                  # Establish loss on approx_depth, a scalar, and
                  # reference_pred, our dense prediction. Normalize both to
                  # prevent degenerative depth shrinking.
                  global_mean_depth_pred = tf.reduce_mean(depth_pred)
                  reference_pred /= global_mean_depth_pred
                  approx_depth /= global_mean_depth_pred
                  spatial_err = tf.abs(reference_pred - approx_depth)
                  mean_spatial_err = tf.reduce_mean(spatial_err)
                  return mean_spatial_err

                losses = tf.map_fn(
                    get_losses, object_masks, dtype=tf.float32)
                self.inf_loss += tf.reduce_mean(losses)
              object_masks_warped = tf.map_fn(  # (N, H, W, 3)
                  process_obj_mask_warp, object_ids, dtype=tf.float32)
              filtered_images = tf.map_fn(
                  lambda mask: current_image * mask, object_masks_warped,
                  dtype=tf.float32)  # (N, H, W, 3)
              color_stack.append(filtered_images)
              mask_stack.append(object_masks)
              mask_stack_warped.append(object_masks_warped)

            # For this batch-element, if there are N moving objects,
            # color_stack, mask_stack and mask_stack_warped each contain
            # seq_length elements of shape (N, H, W, 3).
            # We can now concatenate them on the last axis, creating a tensor of
            # (N, H, W, 3*3 = 9), and, assuming N does not get too large so that
            # we have enough memory, pass them in a single batch to the object
            # motion network.
            mask_stack = tf.concat(mask_stack, axis=3)  # (N, H, W, 9)
            mask_stack_warped = tf.concat(mask_stack_warped, axis=3)
            color_stack = tf.concat(color_stack, axis=3)  # (N, H, W, 9)
            all_transforms = nets.objectmotion_net(
                # We cut the gradient flow here as the object motion gradient
                # should have no say in how the egomotion network behaves.
                # One could try just stopping the gradient for egomotion, but
                # not for the depth prediction network.
                image_stack=tf.stop_gradient(color_stack),
                disp_bottleneck_stack=None,
                joint_encoder=False,  # Joint encoder not supported.
                seq_length=self.seq_length,
                weight_reg=self.weight_reg)
            # all_transforms of shape (N, 2, 6).
            self.object_transforms[s].append(all_transforms)
            self.object_masks[s].append(mask_stack)
            self.object_masks_warped[s].append(mask_stack_warped)
            self.inputs_objectmotion_net[s][i] = color_stack
            tf.get_variable_scope().reuse_variables()
    else:
      # Don't handle motion, classic model formulation.
      with tf.name_scope('egomotion_prediction'):
        if self.joint_encoder:
          # Re-arrange disp_bottleneck_stack to be of shape
          # [B, h_hid, w_hid, c_hid * seq_length]. Currently, it is a list with
          # seq_length elements, each of dimension [B, h_hid, w_hid, c_hid].
          disp_bottleneck_stack = tf.concat(disp_bottlenecks, axis=3)
        else:
          disp_bottleneck_stack = None
        self.egomotion = nets.egomotion_net(
            image_stack=self.image_stack_norm,
            disp_bottleneck_stack=disp_bottleneck_stack,
            joint_encoder=self.joint_encoder,
            seq_length=self.seq_length,
            weight_reg=self.weight_reg)
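
The masking at the start of the handle_motion branch hides every pixel that belongs to a moving object in any of the three frames, i.e. the temporal union of the per-frame object masks. A small NumPy sketch of that logic on a toy (B, H, W, 9) segmentation stack, where background pixels carry id 0:

import numpy as np

B, H, W = 1, 4, 5
seg_stack = np.zeros((B, H, W, 9), dtype=np.uint8)
seg_stack[0, 1, 2, 0:3] = 7   # object id 7 visible in frame 0
seg_stack[0, 2, 3, 3:6] = 7   # the same object in frame 1, shifted

# A pixel is kept only if it is background in all three frames.
static1 = seg_stack[:, :, :, 0] == 0
static2 = seg_stack[:, :, :, 3] == 0
static3 = seg_stack[:, :, :, 6] == 0
mask_complete = (static1 & static2 & static3)[..., np.newaxis]  # (B, H, W, 1)
mask_complete = np.tile(mask_complete, (1, 1, 1, 9)).astype(np.float32)
# Multiplying the normalized image stack by mask_complete zeroes out both
# occurrences of object 7 before the stack reaches the egomotion network.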