Example No. 1
 def build_inference_for_training(self):
   """Invokes depth and ego-motion networks and computes clouds if needed."""
   (self.image_stack, self.intrinsic_mat, self.intrinsic_mat_inv) = (
       self.reader.read_data())
   with tf.name_scope('egomotion_prediction'):
     self.egomotion, _ = nets.egomotion_net(self.image_stack, is_training=True,
                                            legacy_mode=self.legacy_mode)
   with tf.variable_scope('depth_prediction'):
     # Organized by ...[i][scale].  Note that the order is flipped in
     # variables in build_loss() below.
     self.disp = {}
     self.depth = {}
     if self.icp_weight > 0:
       self.cloud = {}
     for i in range(self.seq_length):
       image = self.image_stack[:, :, :, 3 * i:3 * (i + 1)]
       multiscale_disps_i, _ = nets.disp_net(image, is_training=True)
       multiscale_depths_i = [1.0 / d for d in multiscale_disps_i]
       self.disp[i] = multiscale_disps_i
       self.depth[i] = multiscale_depths_i
       if self.icp_weight > 0:
         multiscale_clouds_i = [
             project.get_cloud(d,
                               self.intrinsic_mat_inv[:, s, :, :],
                               name='cloud%d_%d' % (s, i))
             for (s, d) in enumerate(multiscale_depths_i)
         ]
         self.cloud[i] = multiscale_clouds_i
       # Reuse the same depth graph for all images.
       tf.get_variable_scope().reuse_variables()
   logging.info('disp: %s', util.info(self.disp))
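
A note on the indexing above: `image_stack` packs the RGB frames of one training sequence along the channel axis, so frame i lives in channels 3*i through 3*i+2. A minimal NumPy sketch of that layout (hypothetical batch and image sizes):

import numpy as np

# Hypothetical sizes: batch of 4, 128x416 images, a 3-frame sequence stacked
# channel-wise into 3 * 3 = 9 channels.
batch_size, img_height, img_width, seq_length = 4, 128, 416, 3
image_stack = np.zeros((batch_size, img_height, img_width, 3 * seq_length),
                       dtype=np.float32)

for i in range(seq_length):
    frame_i = image_stack[:, :, :, 3 * i:3 * (i + 1)]
    assert frame_i.shape == (batch_size, img_height, img_width, 3)
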
Example No. 2
 def build_inference_for_training(self):
     """Invokes depth and ego-motion networks and computes clouds if needed."""
     (self.image_stack, self.intrinsic_mat,
      self.intrinsic_mat_inv) = (self.reader.read_data())
     with tf.name_scope('egomotion_prediction'):
         self.egomotion, _ = nets.egomotion_net(
             self.image_stack,
             is_training=True,
             legacy_mode=self.legacy_mode)
     with tf.variable_scope('depth_prediction'):
         # Organized by ...[i][scale].  Note that the order is flipped in
         # variables in build_loss() below.
         self.disp = {}
         self.depth = {}
         if self.icp_weight > 0:
             self.cloud = {}
         for i in range(self.seq_length):
             image = self.image_stack[:, :, :, 3 * i:3 * (i + 1)]
             multiscale_disps_i, _ = nets.disp_net(image, is_training=True)
             multiscale_depths_i = [1.0 / d for d in multiscale_disps_i]
             self.disp[i] = multiscale_disps_i
             self.depth[i] = multiscale_depths_i
             if self.icp_weight > 0:
                 multiscale_clouds_i = [
                     project.get_cloud(d,
                                       self.intrinsic_mat_inv[:, s, :, :],
                                       name='cloud%d_%d' % (s, i))
                     for (s, d) in enumerate(multiscale_depths_i)
                 ]
                 self.cloud[i] = multiscale_clouds_i
             # Reuse the same depth graph for all images.
             tf.get_variable_scope().reuse_variables()
     logging.info('disp: %s', util.info(self.disp))
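
`nets.disp_net` returns a list of disparity maps, one per scale, and depth is simply the element-wise reciprocal of each. A small NumPy sketch of that list structure (hypothetical scale count and resolutions):

import numpy as np

# Hypothetical 4-scale disparity pyramid for a batch of 4; values are kept
# strictly positive so the reciprocal is well defined.
multiscale_disps = [
    np.full((4, 128 // 2 ** s, 416 // 2 ** s, 1), 0.5 + s, dtype=np.float32)
    for s in range(4)
]
multiscale_depths = [1.0 / d for d in multiscale_disps]
print([d.shape for d in multiscale_depths])
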
Example No. 3
 def build_depth_test_graph(self):
     with tf.name_scope('depth_prediction'):
         with tf.variable_scope('depth_prediction'):
            input_uint8 = tf.placeholder(
                tf.uint8, [self.batch_size, self.img_height, self.img_width, 3],
                name='raw_input')
             input_float = tf.image.convert_image_dtype(input_uint8, tf.float32)
             # TODO(rezama): Retrain published model with batchnorm params and set
             # is_training to False.
             est_disp, _ = nets.disp_net(input_float, is_training=True)
             est_depth = 1.0 / est_disp[0]
     self.inputs_depth = input_uint8
     self.est_depth = est_depth
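
A sketch of how a test graph like this is typically driven with a feed dict, assuming the TensorFlow 1.x API; the toy reduction below is a stand-in for nets.disp_net, not the real model:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API assumed

# Stand-in graph with the same input/output contract as the method above:
# a uint8 placeholder in, a float depth map out.
inputs_depth = tf.placeholder(tf.uint8, [1, 128, 416, 3], name='raw_input')
input_float = tf.image.convert_image_dtype(inputs_depth, tf.float32)
est_depth = 1.0 / (tf.reduce_mean(input_float, axis=3, keepdims=True) + 1e-3)

with tf.Session() as sess:
    raw = np.random.randint(0, 256, size=(1, 128, 416, 3), dtype=np.uint8)
    depth = sess.run(est_depth, feed_dict={inputs_depth: raw})
    print(depth.shape)  # (1, 128, 416, 1)
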
Example No. 4
 def build_depth_test_graph(self):
      input_uint8 = tf.placeholder(
          tf.uint8, [self.batch_size, self.img_height, self.img_width, 3],
          name='raw_input')
     input_mc = self.preprocess_image(input_uint8)
     with tf.name_scope("depth_prediction"):
         pred_disp, depth_net_endpoints = disp_net(
             input_mc, is_training=False)
         pred_depth = [1./disp for disp in pred_disp]
     pred_depth = pred_depth[0]
     self.inputs = input_uint8
     self.pred_depth = pred_depth
     self.depth_epts = depth_net_endpoints
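
`self.preprocess_image` is not shown here; in SfMLearner-style code it usually maps uint8 pixels into a symmetric float range. A minimal sketch of that convention, given as an assumption about the helper rather than its actual definition:

import tensorflow as tf  # TensorFlow 1.x API assumed

def preprocess_image(image_uint8):
    """Hypothetical helper: uint8 [0, 255] -> float32 [-1, 1]."""
    image = tf.image.convert_image_dtype(image_uint8, tf.float32)  # [0, 1]
    return image * 2.0 - 1.0
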
Example No. 5
 def build_depth_test_graph(self):
   """Builds depth model reading from placeholders."""
   with tf.name_scope('depth_prediction'):
     with tf.variable_scope('depth_prediction'):
       input_uint8 = tf.placeholder(
           tf.uint8, [self.batch_size, self.img_height, self.img_width, 3],
           name='raw_input')
       input_float = tf.image.convert_image_dtype(input_uint8, tf.float32)
       # TODO(rezama): Retrain published model with batchnorm params and set
       # is_training to False.
       est_disp, _ = nets.disp_net(input_float, is_training=True)
       est_depth = 1.0 / est_disp[0]
   self.inputs_depth = input_uint8
   self.est_depth = est_depth
Example No. 6
 def build_depth_test_graph(self):
   """Builds depth model reading from placeholders."""
   with tf.variable_scope('depth_prediction'):
     input_image = tf.placeholder(
         tf.float32, [self.batch_size, self.img_height, self.img_width, 3],
         name='raw_input')
     if self.imagenet_norm:
       input_image = (input_image - reader.IMAGENET_MEAN) / reader.IMAGENET_SD
     est_disp, _ = nets.disp_net(architecture=self.architecture,
                                 image=input_image,
                                 use_skip=self.use_skip,
                                 weight_reg=self.weight_reg,
                                 is_training=True)
   est_depth = 1.0 / est_disp[0]
   self.input_image = input_image
   self.est_depth = est_depth
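
The `imagenet_norm` branch standardizes the input with per-channel ImageNet statistics. A minimal sketch with the commonly used mean/std values; the exact contents of reader.IMAGENET_MEAN and reader.IMAGENET_SD are an assumption here:

import tensorflow as tf  # TensorFlow 1.x API assumed

# Commonly used ImageNet statistics for images already scaled to [0, 1]
# (assumed values; the reader module defines the actual constants).
IMAGENET_MEAN = tf.constant([0.485, 0.456, 0.406], shape=[1, 1, 1, 3])
IMAGENET_SD = tf.constant([0.229, 0.224, 0.225], shape=[1, 1, 1, 3])

def imagenet_normalize(image):
    """image: float32 tensor of shape (B, H, W, 3) with values in [0, 1]."""
    return (image - IMAGENET_MEAN) / IMAGENET_SD
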
Example No. 7
 def build_inference_for_training(self):
     (self.image_stack, self.intrinsic_mat, self.intrinsic_mat_inv) = (self.reader.read_data())
     with tf.name_scope('egomotion_prediction'):
         self.egomotion, _ = nets.egomotion_net(self.image_stack, is_training=True, legacy_mode=self.legacy_mode)
     with tf.variable_scope('depth_prediction'):
         # Organized by ...[i][scale].  Note that the order is flipped in
         # variables in build_loss() below.
         self.disp = {}
         self.depth = {}
         for i in range(self.seq_length):
             image = self.image_stack[:, :, :, 3 * i:3 * (i + 1)]
             multiscale_disps_i, _ = nets.disp_net(image, is_training=True)
             multiscale_depths_i = [1.0 / d for d in multiscale_disps_i]
             self.disp[i] = multiscale_disps_i
             self.depth[i] = multiscale_depths_i
             # Reuse the same depth graph for all images.
             tf.get_variable_scope().reuse_variables()
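
The `tf.get_variable_scope().reuse_variables()` call at the end of the loop is what makes all frames share a single depth network. A self-contained TF 1.x sketch of that weight-sharing pattern with a toy layer:

import tensorflow as tf  # TensorFlow 1.x API assumed

def toy_net(x):
    # Creates (or reuses) variables named shared/dense/kernel and .../bias.
    return tf.layers.dense(x, 8, name='dense')

inputs = [tf.ones([2, 4]) for _ in range(3)]
outputs = []
with tf.variable_scope('shared'):
    for x in inputs:
        outputs.append(toy_net(x))
        # After the first frame, reuse the same weights for all later frames.
        tf.get_variable_scope().reuse_variables()

# Only one kernel/bias pair exists despite three calls.
print([v.name for v in tf.global_variables()])
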
Example No. 8
    def build_single_depth_test_graph(self):
        """Assume batch size is 1"""
        with tf.variable_scope('depth_prediction'):
            input_image = tf.placeholder(
                tf.float32, [1, self.img_height, self.img_width, 3],
                name='raw_input')
            if self.imagenet_norm:
                input_image = (input_image - reader.IMAGENET_MEAN) / reader.IMAGENET_SD

            tf.get_variable_scope().reuse_variables()  # Note: reuse variable
            est_disp, _ = nets.disp_net(architecture=self.architecture,
                                        image=input_image,
                                        use_skip=self.use_skip,
                                        weight_reg=self.weight_reg,
                                        is_training=True)

        est_depth = 1.0 / est_disp[0]
        self.input_image = input_image
        self.est_depth = est_depth
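
The example above attaches a test placeholder to weights that already exist in the 'depth_prediction' scope. A minimal TF 1.x sketch of the same idea, using an explicit reuse=True scope and a toy convolution in place of nets.disp_net:

import tensorflow as tf  # TensorFlow 1.x API assumed

def toy_disp_net(image):
    # Stand-in for nets.disp_net: one conv layer named 'disp'.
    return tf.layers.conv2d(image, 1, 3, padding='same', name='disp')

# The training tower creates the variables.
with tf.variable_scope('depth_prediction'):
    train_disp = toy_disp_net(tf.ones([4, 128, 416, 3]))

# The test tower re-enters the same scope with reuse=True, so the batch-size-1
# placeholder runs through the weights created above instead of new ones.
with tf.variable_scope('depth_prediction', reuse=True):
    test_input = tf.placeholder(tf.float32, [1, 128, 416, 3], name='raw_input')
    test_disp = toy_disp_net(test_input)
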
Example No. 9
  def build_inference_for_training(self):
    """Invokes depth and ego-motion networks and computes clouds if needed."""
    (self.image_stack, self.image_stack_norm, self.seg_stack,
     self.intrinsic_mat, self.intrinsic_mat_inv) = self.reader.read_data()
    with tf.variable_scope('depth_prediction'):
      # Organized by ...[i][scale].  Note that the order is flipped in
      # variables in build_loss() below.
      self.disp = {}
      self.depth = {}
      self.depth_upsampled = {}
      self.inf_loss = 0.0
      # Organized by [i].
      disp_bottlenecks = [None] * self.seq_length

      if self.icp_weight > 0:
        self.cloud = {}
      for i in range(self.seq_length):
        image = self.image_stack_norm[:, :, :, 3 * i:3 * (i + 1)]

        multiscale_disps_i, disp_bottlenecks[i] = nets.disp_net(
            self.architecture, image, self.use_skip,
            self.weight_reg, True)
        multiscale_depths_i = [1.0 / d for d in multiscale_disps_i]
        self.disp[i] = multiscale_disps_i
        self.depth[i] = multiscale_depths_i
        if self.depth_upsampling:
          self.depth_upsampled[i] = []
          # Upsample low-resolution depth maps using differentiable bilinear
          # interpolation.
          for s in range(len(multiscale_depths_i)):
            self.depth_upsampled[i].append(tf.image.resize_bilinear(
                multiscale_depths_i[s], [self.img_height, self.img_width],
                align_corners=True))

        if self.icp_weight > 0:
          multiscale_clouds_i = [
              project.get_cloud(d,
                                self.intrinsic_mat_inv[:, s, :, :],
                                name='cloud%d_%d' % (s, i))
              for (s, d) in enumerate(multiscale_depths_i)
          ]
          self.cloud[i] = multiscale_clouds_i
        # Reuse the same depth graph for all images.
        tf.get_variable_scope().reuse_variables()

    if self.handle_motion:
      # Define egomotion network. This network can see the whole scene except
      # for any moving objects as indicated by the provided segmentation masks.
      # To avoid giving the network motion cues by letting it track those
      # masks, the segmentation mask used here is the temporal union over all
      # frames.
      print('HANDLE MOTION')
      with tf.variable_scope('egomotion_prediction'):
        base_input = self.image_stack_norm  # (B, H, W, 9)
        seg_input = self.seg_stack  # (B, H, W, 9)
        ref_zero = tf.constant(0, dtype=tf.uint8)
        # Motion model is currently defined for three-frame sequences.
        object_mask1 = tf.equal(seg_input[:, :, :, 0], ref_zero)
        object_mask2 = tf.equal(seg_input[:, :, :, 3], ref_zero)
        object_mask3 = tf.equal(seg_input[:, :, :, 6], ref_zero)
        mask_complete = tf.expand_dims(tf.logical_and(  # (B, H, W, 1)
            tf.logical_and(object_mask1, object_mask2), object_mask3), axis=3)
        mask_complete = tf.tile(mask_complete, (1, 1, 1, 9))  # (B, H, W, 9)
        # Now mask out base_input.
        self.mask_complete = tf.to_float(mask_complete)
        self.base_input_masked = base_input * self.mask_complete
        self.egomotion = nets.egomotion_net(
            image_stack=self.base_input_masked,
            disp_bottleneck_stack=None,
            joint_encoder=False,
            seq_length=self.seq_length,
            weight_reg=self.weight_reg)

        # Debug: log the symbolic egomotion tensor. Evaluating it here would
        # require a session with initialized variables, so only the tensor
        # itself is printed.
        print('egomotion = ', self.egomotion)

      # Define object motion network for refinement. This network only sees
      # one object at a time over the whole sequence, and tries to estimate its
      # motion. The sequence of images are the respective warped frames.

      # For each scale, contains batch_size elements of shape (N, 2, 6).
      self.object_transforms = {}
      # For each scale, contains batch_size elements of shape (N, H, W, 9).
      self.object_masks = {}
      self.object_masks_warped = {}
      # For each scale, contains batch_size elements of size N.
      self.object_ids = {}

      self.egomotions_seq = {}
      self.warped_seq = {}
      self.inputs_objectmotion_net = {}
      with tf.variable_scope('objectmotion_prediction'):
        # First, warp raw images according to overall egomotion.
        for s in range(NUM_SCALES):
          self.warped_seq[s] = []
          self.egomotions_seq[s] = []
          for source_index in range(self.seq_length):
            egomotion_mat_i_1 = project.get_transform_mat(
                self.egomotion, source_index, 1)
            warped_image_i_1, _ = (
                project.inverse_warp(
                    self.image_stack[
                        :, :, :, source_index*3:(source_index+1)*3],
                    self.depth_upsampled[1][s],
                    egomotion_mat_i_1,
                    self.intrinsic_mat[:, 0, :, :],
                    self.intrinsic_mat_inv[:, 0, :, :]))

            self.warped_seq[s].append(warped_image_i_1)
            self.egomotions_seq[s].append(egomotion_mat_i_1)

          # Second, for every object in the segmentation mask, take its mask and
          # warp it according to the egomotion estimate. Then put a threshold to
          # binarize the warped result. Use this mask to mask out background and
          # other objects, and pass the filtered image to the object motion
          # network.
          self.object_transforms[s] = []
          self.object_masks[s] = []
          self.object_ids[s] = []
          self.object_masks_warped[s] = []
          self.inputs_objectmotion_net[s] = {}

          for i in range(self.batch_size):
            seg_sequence = self.seg_stack[i]  # (H, W, 9=3*3)
            object_ids = tf.unique(tf.reshape(seg_sequence, [-1]))[0]
            self.object_ids[s].append(object_ids)
            color_stack = []
            mask_stack = []
            mask_stack_warped = []
            for j in range(self.seq_length):
              current_image = self.warped_seq[s][j][i]  # (H, W, 3)
              current_seg = seg_sequence[:, :, j * 3:(j+1) * 3]  # (H, W, 3)

              def process_obj_mask_warp(obj_id):
                """Performs warping of the individual object masks."""
                obj_mask = tf.to_float(tf.equal(current_seg, obj_id))
                # Warp obj_mask according to overall egomotion.
                obj_mask_warped, _ = (
                    project.inverse_warp(
                        tf.expand_dims(obj_mask, axis=0),
                        # Middle frame, highest scale, batch element i:
                        tf.expand_dims(self.depth_upsampled[1][s][i], axis=0),
                        # Matrix for warping j into middle frame, batch elem. i:
                        tf.expand_dims(self.egomotions_seq[s][j][i], axis=0),
                        tf.expand_dims(self.intrinsic_mat[i, 0, :, :], axis=0),
                        tf.expand_dims(self.intrinsic_mat_inv[i, 0, :, :],
                                       axis=0)))
                obj_mask_warped = tf.squeeze(obj_mask_warped)
                obj_mask_binarized = tf.greater(  # Threshold to binarize mask.
                    obj_mask_warped, tf.constant(0.5))
                return tf.to_float(obj_mask_binarized)

              def process_obj_mask(obj_id):
                """Returns the individual object masks separately."""
                return tf.to_float(tf.equal(current_seg, obj_id))
              object_masks = tf.map_fn(  # (N, H, W, 3)
                  process_obj_mask, object_ids, dtype=tf.float32)

              if self.size_constraint_weight > 0:
                # The object segmentation masks are all in object_masks.
                # We need to measure the height of each of them and get the
                # approximate distance.

                # self.depth_upsampled of shape (seq_length, scale, B, H, W).
                depth_pred = self.depth_upsampled[j][s][i]  # (H, W)
                def get_losses(obj_mask):
                  """Get motion constraint loss."""
                  # Find height of segment.
                  coords = tf.where(tf.greater(  # Shape (num_true, 2=yx)
                      obj_mask[:, :, 0], tf.constant(0.5, dtype=tf.float32)))
                  y_max = tf.reduce_max(coords[:, 0])
                  y_min = tf.reduce_min(coords[:, 0])
                  seg_height = y_max - y_min
                  f_y = self.intrinsic_mat[i, 0, 1, 1]
                  approx_depth = ((f_y * self.global_scale_var) /
                                  tf.to_float(seg_height))
                  reference_pred = tf.boolean_mask(
                      depth_pred, tf.greater(
                          tf.reshape(obj_mask[:, :, 0],
                                     (self.img_height, self.img_width, 1)),
                          tf.constant(0.5, dtype=tf.float32)))

                  # Establish loss on approx_depth, a scalar, and
                  # reference_pred, our dense prediction. Normalize both to
                  # prevent degenerative depth shrinking.
                  global_mean_depth_pred = tf.reduce_mean(depth_pred)
                  reference_pred /= global_mean_depth_pred
                  approx_depth /= global_mean_depth_pred
                  spatial_err = tf.abs(reference_pred - approx_depth)
                  print('spatial error =', spatial_err)
                  mean_spatial_err = tf.reduce_mean(spatial_err)
                  return mean_spatial_err

                losses = tf.map_fn(
                    get_losses, object_masks, dtype=tf.float32)
                print('Losses = ', losses)
                self.inf_loss += tf.reduce_mean(losses)
                print('self.inf_loss = ', self.inf_loss)
              object_masks_warped = tf.map_fn(  # (N, H, W, 3)
                  process_obj_mask_warp, object_ids, dtype=tf.float32)
              filtered_images = tf.map_fn(
                  lambda mask: current_image * mask, object_masks_warped,
                  dtype=tf.float32)  # (N, H, W, 3)
              color_stack.append(filtered_images)
              mask_stack.append(object_masks)
              mask_stack_warped.append(object_masks_warped)

            # For this batch-element, if there are N moving objects,
            # color_stack, mask_stack and mask_stack_warped each contain
            # seq_length elements of shape (N, H, W, 3).
            # We can now concatenate them on the last axis, creating a tensor of
            # (N, H, W, 3*3 = 9), and, assuming N does not get too large so that
            # we have enough memory, pass them in a single batch to the object
            # motion network.
            mask_stack = tf.concat(mask_stack, axis=3)  # (N, H, W, 9)
            mask_stack_warped = tf.concat(mask_stack_warped, axis=3)
            color_stack = tf.concat(color_stack, axis=3)  # (N, H, W, 9)
            all_transforms = nets.objectmotion_net(
                # We cut the gradient flow here as the object motion gradient
                # should have no say in how the egomotion network behaves.
                # One could try just stopping the gradient for egomotion, but
                # not for the depth prediction network.
                image_stack=tf.stop_gradient(color_stack),
                disp_bottleneck_stack=None,
                joint_encoder=False,  # Joint encoder not supported.
                seq_length=self.seq_length,
                weight_reg=self.weight_reg)
            # all_transforms of shape (N, 2, 6).
            self.object_transforms[s].append(all_transforms)
            self.object_masks[s].append(mask_stack)
            self.object_masks_warped[s].append(mask_stack_warped)
            self.inputs_objectmotion_net[s][i] = color_stack
            tf.get_variable_scope().reuse_variables()
      print('HANDLE MOTION22222')
    else:
      # Don't handle motion, classic model formulation.
      with tf.name_scope('egomotion_prediction'):
        if self.joint_encoder:
          # Re-arrange disp_bottleneck_stack to be of shape
          # [B, h_hid, w_hid, c_hid * seq_length]. Currently, it is a list with
          # seq_length elements, each of dimension [B, h_hid, w_hid, c_hid].
          disp_bottleneck_stack = tf.concat(disp_bottlenecks, axis=3)
        else:
          disp_bottleneck_stack = None
        self.egomotion = nets.egomotion_net(
            image_stack=self.image_stack_norm,
            disp_bottleneck_stack=disp_bottleneck_stack,
            joint_encoder=self.joint_encoder,
            seq_length=self.seq_length,
            weight_reg=self.weight_reg)
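
In the motion-handling branch above, the egomotion network only sees pixels that are background in all three frames; everything that is a potentially moving object in any frame is zeroed out. A self-contained TF 1.x sketch of that temporal-union masking with toy shapes:

import tensorflow as tf  # TensorFlow 1.x API assumed

B, H, W = 2, 8, 8  # toy sizes
image_stack_norm = tf.ones([B, H, W, 9])            # three RGB frames, stacked
seg_stack = tf.zeros([B, H, W, 9], dtype=tf.uint8)  # 0 = background everywhere

ref_zero = tf.constant(0, dtype=tf.uint8)
# Channels 0, 3 and 6 carry the object ids of the three frames.
bg1 = tf.equal(seg_stack[:, :, :, 0], ref_zero)
bg2 = tf.equal(seg_stack[:, :, :, 3], ref_zero)
bg3 = tf.equal(seg_stack[:, :, :, 6], ref_zero)
# Keep a pixel only if it is background in *all* frames, i.e. remove the
# temporal union of the object masks, then broadcast to all 9 channels.
mask = tf.expand_dims(tf.logical_and(tf.logical_and(bg1, bg2), bg3), axis=3)
mask = tf.tile(mask, (1, 1, 1, 9))
base_input_masked = image_stack_norm * tf.to_float(mask)
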
Example No. 10
  def build_inference_for_training(self):
    """Invokes depth and ego-motion networks and computes clouds if needed."""
    (self.image_stack, self.image_stack_norm, self.seg_stack,
     self.intrinsic_mat, self.intrinsic_mat_inv) = self.reader.read_data()
    with tf.variable_scope('depth_prediction'):
      # Organized by ...[i][scale].  Note that the order is flipped in
      # variables in build_loss() below.
      self.disp = {}
      self.depth = {}
      self.depth_upsampled = {}
      self.inf_loss = 0.0
      # Organized by [i].
      disp_bottlenecks = [None] * self.seq_length

      if self.icp_weight > 0:
        self.cloud = {}
      for i in range(self.seq_length):
        image = self.image_stack_norm[:, :, :, 3 * i:3 * (i + 1)]

        multiscale_disps_i, disp_bottlenecks[i] = nets.disp_net(
            self.architecture, image, self.use_skip,
            self.weight_reg, True)
        multiscale_depths_i = [1.0 / d for d in multiscale_disps_i]
        self.disp[i] = multiscale_disps_i
        self.depth[i] = multiscale_depths_i
        if self.depth_upsampling:
          self.depth_upsampled[i] = []
          # Upsample low-resolution depth maps using differentiable bilinear
          # interpolation.
          for s in range(len(multiscale_depths_i)):
            self.depth_upsampled[i].append(tf.image.resize_bilinear(
                multiscale_depths_i[s], [self.img_height, self.img_width],
                align_corners=True))

        if self.icp_weight > 0:
          multiscale_clouds_i = [
              project.get_cloud(d,
                                self.intrinsic_mat_inv[:, s, :, :],
                                name='cloud%d_%d' % (s, i))
              for (s, d) in enumerate(multiscale_depths_i)
          ]
          self.cloud[i] = multiscale_clouds_i
        # Reuse the same depth graph for all images.
        tf.get_variable_scope().reuse_variables()

    if self.handle_motion:
      # Define egomotion network. This network can see the whole scene except
      # for any moving objects as indicated by the provided segmentation masks.
      # To avoid giving the network motion cues by letting it track those
      # masks, the segmentation mask used here is the temporal union over all
      # frames.
      with tf.variable_scope('egomotion_prediction'):
        base_input = self.image_stack_norm  # (B, H, W, 9)
        seg_input = self.seg_stack  # (B, H, W, 9)
        ref_zero = tf.constant(0, dtype=tf.uint8)
        # Motion model is currently defined for three-frame sequences.
        object_mask1 = tf.equal(seg_input[:, :, :, 0], ref_zero)
        object_mask2 = tf.equal(seg_input[:, :, :, 3], ref_zero)
        object_mask3 = tf.equal(seg_input[:, :, :, 6], ref_zero)
        mask_complete = tf.expand_dims(tf.logical_and(  # (B, H, W, 1)
            tf.logical_and(object_mask1, object_mask2), object_mask3), axis=3)
        mask_complete = tf.tile(mask_complete, (1, 1, 1, 9))  # (B, H, W, 9)
        # Now mask out base_input.
        self.mask_complete = tf.to_float(mask_complete)
        self.base_input_masked = base_input * self.mask_complete
        self.egomotion = nets.egomotion_net(
            image_stack=self.base_input_masked,
            disp_bottleneck_stack=None,
            joint_encoder=False,
            seq_length=self.seq_length,
            weight_reg=self.weight_reg)

      # Define object motion network for refinement. This network only sees
      # one object at a time over the whole sequence, and tries to estimate its
      # motion. The sequence of images are the respective warped frames.

      # For each scale, contains batch_size elements of shape (N, 2, 6).
      self.object_transforms = {}
      # For each scale, contains batch_size elements of shape (N, H, W, 9).
      self.object_masks = {}
      self.object_masks_warped = {}
      # For each scale, contains batch_size elements of size N.
      self.object_ids = {}

      self.egomotions_seq = {}
      self.warped_seq = {}
      self.inputs_objectmotion_net = {}
      with tf.variable_scope('objectmotion_prediction'):
        # First, warp raw images according to overall egomotion.
        for s in range(NUM_SCALES):
          self.warped_seq[s] = []
          self.egomotions_seq[s] = []
          for source_index in range(self.seq_length):
            egomotion_mat_i_1 = project.get_transform_mat(
                self.egomotion, source_index, 1)
            warped_image_i_1, _ = (
                project.inverse_warp(
                    self.image_stack[
                        :, :, :, source_index*3:(source_index+1)*3],
                    self.depth_upsampled[1][s],
                    egomotion_mat_i_1,
                    self.intrinsic_mat[:, 0, :, :],
                    self.intrinsic_mat_inv[:, 0, :, :]))

            self.warped_seq[s].append(warped_image_i_1)
            self.egomotions_seq[s].append(egomotion_mat_i_1)

          # Second, for every object in the segmentation mask, take its mask and
          # warp it according to the egomotion estimate. Then put a threshold to
          # binarize the warped result. Use this mask to mask out background and
          # other objects, and pass the filtered image to the object motion
          # network.
          self.object_transforms[s] = []
          self.object_masks[s] = []
          self.object_ids[s] = []
          self.object_masks_warped[s] = []
          self.inputs_objectmotion_net[s] = {}

          for i in range(self.batch_size):
            seg_sequence = self.seg_stack[i]  # (H, W, 9=3*3)
            object_ids = tf.unique(tf.reshape(seg_sequence, [-1]))[0]
            self.object_ids[s].append(object_ids)
            color_stack = []
            mask_stack = []
            mask_stack_warped = []
            for j in range(self.seq_length):
              current_image = self.warped_seq[s][j][i]  # (H, W, 3)
              current_seg = seg_sequence[:, :, j * 3:(j+1) * 3]  # (H, W, 3)

              def process_obj_mask_warp(obj_id):
                """Performs warping of the individual object masks."""
                obj_mask = tf.to_float(tf.equal(current_seg, obj_id))
                # Warp obj_mask according to overall egomotion.
                obj_mask_warped, _ = (
                    project.inverse_warp(
                        tf.expand_dims(obj_mask, axis=0),
                        # Middle frame, highest scale, batch element i:
                        tf.expand_dims(self.depth_upsampled[1][s][i], axis=0),
                        # Matrix for warping j into middle frame, batch elem. i:
                        tf.expand_dims(self.egomotions_seq[s][j][i], axis=0),
                        tf.expand_dims(self.intrinsic_mat[i, 0, :, :], axis=0),
                        tf.expand_dims(self.intrinsic_mat_inv[i, 0, :, :],
                                       axis=0)))
                obj_mask_warped = tf.squeeze(obj_mask_warped)
                obj_mask_binarized = tf.greater(  # Threshold to binarize mask.
                    obj_mask_warped, tf.constant(0.5))
                return tf.to_float(obj_mask_binarized)

              def process_obj_mask(obj_id):
                """Returns the individual object masks separately."""
                return tf.to_float(tf.equal(current_seg, obj_id))
              object_masks = tf.map_fn(  # (N, H, W, 3)
                  process_obj_mask, object_ids, dtype=tf.float32)

              if self.size_constraint_weight > 0:
                # The object segmentation masks are all in object_masks.
                # We need to measure the height of each of them and get the
                # approximate distance.

                # self.depth_upsampled of shape (seq_length, scale, B, H, W).
                depth_pred = self.depth_upsampled[j][s][i]  # (H, W)
                def get_losses(obj_mask):
                  """Get motion constraint loss."""
                  # Find height of segment.
                  coords = tf.where(tf.greater(  # Shape (num_true, 2=yx)
                      obj_mask[:, :, 0], tf.constant(0.5, dtype=tf.float32)))
                  y_max = tf.reduce_max(coords[:, 0])
                  y_min = tf.reduce_min(coords[:, 0])
                  seg_height = y_max - y_min
                  f_y = self.intrinsic_mat[i, 0, 1, 1]
                  approx_depth = ((f_y * self.global_scale_var) /
                                  tf.to_float(seg_height))
                  reference_pred = tf.boolean_mask(
                      depth_pred, tf.greater(
                          tf.reshape(obj_mask[:, :, 0],
                                     (self.img_height, self.img_width, 1)),
                          tf.constant(0.5, dtype=tf.float32)))

                  # Establish loss on approx_depth, a scalar, and
                  # reference_pred, our dense prediction. Normalize both to
                  # prevent degenerative depth shrinking.
                  global_mean_depth_pred = tf.reduce_mean(depth_pred)
                  reference_pred /= global_mean_depth_pred
                  approx_depth /= global_mean_depth_pred
                  spatial_err = tf.abs(reference_pred - approx_depth)
                  mean_spatial_err = tf.reduce_mean(spatial_err)
                  return mean_spatial_err

                losses = tf.map_fn(
                    get_losses, object_masks, dtype=tf.float32)
                self.inf_loss += tf.reduce_mean(losses)
              object_masks_warped = tf.map_fn(  # (N, H, W, 3)
                  process_obj_mask_warp, object_ids, dtype=tf.float32)
              filtered_images = tf.map_fn(
                  lambda mask: current_image * mask, object_masks_warped,
                  dtype=tf.float32)  # (N, H, W, 3)
              color_stack.append(filtered_images)
              mask_stack.append(object_masks)
              mask_stack_warped.append(object_masks_warped)

            # For this batch-element, if there are N moving objects,
            # color_stack, mask_stack and mask_stack_warped each contain
            # seq_length elements of shape (N, H, W, 3).
            # We can now concatenate them on the last axis, creating a tensor of
            # (N, H, W, 3*3 = 9), and, assuming N does not get too large so that
            # we have enough memory, pass them in a single batch to the object
            # motion network.
            mask_stack = tf.concat(mask_stack, axis=3)  # (N, H, W, 9)
            mask_stack_warped = tf.concat(mask_stack_warped, axis=3)
            color_stack = tf.concat(color_stack, axis=3)  # (N, H, W, 9)
            all_transforms = nets.objectmotion_net(
                # We cut the gradient flow here as the object motion gradient
                # should have no say in how the egomotion network behaves.
                # One could try just stopping the gradient for egomotion, but
                # not for the depth prediction network.
                image_stack=tf.stop_gradient(color_stack),
                disp_bottleneck_stack=None,
                joint_encoder=False,  # Joint encoder not supported.
                seq_length=self.seq_length,
                weight_reg=self.weight_reg)
            # all_transforms of shape (N, 2, 6).
            self.object_transforms[s].append(all_transforms)
            self.object_masks[s].append(mask_stack)
            self.object_masks_warped[s].append(mask_stack_warped)
            self.inputs_objectmotion_net[s][i] = color_stack
            tf.get_variable_scope().reuse_variables()
    else:
      # Don't handle motion, classic model formulation.
      with tf.name_scope('egomotion_prediction'):
        if self.joint_encoder:
          # Re-arrange disp_bottleneck_stack to be of shape
          # [B, h_hid, w_hid, c_hid * seq_length]. Currently, it is a list with
          # seq_length elements, each of dimension [B, h_hid, w_hid, c_hid].
          disp_bottleneck_stack = tf.concat(disp_bottlenecks, axis=3)
        else:
          disp_bottleneck_stack = None
        self.egomotion = nets.egomotion_net(
            image_stack=self.image_stack_norm,
            disp_bottleneck_stack=disp_bottleneck_stack,
            joint_encoder=self.joint_encoder,
            seq_length=self.seq_length,
            weight_reg=self.weight_reg)
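
The per-object loop above relies on `tf.unique` to enumerate the segment ids of one batch element and `tf.map_fn` to build one binary mask per id. A minimal TF 1.x sketch of that pattern on a toy segmentation:

import tensorflow as tf  # TensorFlow 1.x API assumed

H, W = 6, 6
# Toy single-frame segmentation: id 0 is background, ids 1 and 2 are objects.
seg = tf.constant([[0] * W] * (H - 2) + [[1] * W, [2] * W], dtype=tf.uint8)
seg = tf.stack([seg] * 3, axis=2)                 # (H, W, 3), like one RGB slot

object_ids = tf.unique(tf.reshape(seg, [-1]))[0]  # distinct ids, e.g. [0, 1, 2]

def process_obj_mask(obj_id):
    """One float mask per object id."""
    return tf.to_float(tf.equal(seg, obj_id))

object_masks = tf.map_fn(process_obj_mask, object_ids, dtype=tf.float32)
# object_masks has shape (N, H, W, 3), where N is the number of distinct ids.
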
Example No. 11
    def build_inference_for_training(self):
        """Invokes depth and ego-motion networks."""
        if self.is_training:
            (self.image_stack, self.image_stack_norm, self.seg_stack,
             self.intrinsic_mat, self.intrinsic_mat_inv) = self.reader.read_data()

        with tf.variable_scope('depth_prediction'):
            # Organized by ...[i][scale].  Note that the order is flipped in
            # variables in build_loss() below.
            self.disp = {}
            self.depth = {}
            self.depth_upsampled = {}
            self.object_depth_loss = 0.0
            # Organized by [i].
            disp_bottlenecks = [None] * self.seq_length

            for i in range(self.seq_length):
                image = self.image_stack_norm[:, :, :, 3 * i:3 * (i + 1)]

                multiscale_disps_i, disp_bottlenecks[i] = nets.disp_net(
                    self.architecture, image, self.use_skip,
                    self.weight_reg, True)

                multiscale_depths_i = [1.0 / d for d in multiscale_disps_i]
                self.disp[i] = multiscale_disps_i
                self.depth[i] = multiscale_depths_i
                if self.depth_upsampling:
                    self.depth_upsampled[i] = []
                    # Upsample low-resolution depth maps using differentiable bilinear
                    # interpolation.
                    for s in range(len(multiscale_depths_i)):
                        self.depth_upsampled[i].append(tf.image.resize_bilinear(
                            multiscale_depths_i[s], [self.img_height, self.img_width],
                            align_corners=True))

                # Reuse the same depth graph for all images.
                tf.get_variable_scope().reuse_variables()

        if self.handle_motion:
            # Define egomotion network. This network can see the whole scene except
            # for any moving objects as indicated by the provided segmentation masks.
            # To avoid giving the network motion cues by letting it track those
            # masks, the segmentation mask used here is the temporal union over
            # all frames.
            with tf.variable_scope('egomotion_prediction'):
                base_input = self.image_stack_norm  # (B, H, W, 9)
                seg_input = self.seg_stack  # (B, H, W, 9)
                ref_zero = tf.constant(0, dtype=tf.uint8)
                # Motion model is currently defined for three-frame sequences.
                object_mask1 = tf.equal(seg_input[:, :, :, 0], ref_zero)
                object_mask2 = tf.equal(seg_input[:, :, :, 3], ref_zero)
                object_mask3 = tf.equal(seg_input[:, :, :, 6], ref_zero)
                mask_complete = tf.expand_dims(tf.logical_and(  # (B, H, W, 1)
                    tf.logical_and(object_mask1, object_mask2), object_mask3), axis=3)
                mask_complete = tf.tile(mask_complete, (1, 1, 1, 9))  # (B, H, W, 9)
                # Now mask out base_input.
                self.mask_complete = tf.to_float(mask_complete)
                self.base_input_masked = base_input * self.mask_complete  # [B, H, W, 9]

                self.egomotion = nets.egomotion_net(
                    image_stack=self.base_input_masked,
                    disp_bottleneck_stack=None,
                    joint_encoder=False,
                    seq_length=self.seq_length,
                    weight_reg=self.weight_reg,
                    same_trans_rot_scaling=self.same_trans_rot_scaling)

            # Define object motion network for refinement. This network only sees
            # one object at a time over the whole sequence, and tries to estimate its
            # motion. The sequence of images are the respective warped frames.

            # For each scale, contains batch_size elements of shape (N, 2, 6).
            self.object_transforms = {}
            # For each scale, contains batch_size elements of shape (N, H, W, 9).
            self.object_masks = {}
            self.object_masks_warped = {}
            # For each scale, contains batch_size elements of size N.
            self.object_ids = {}

            self.egomotions_seq = {}
            self.warped_seq = {}
            # For each scale, contains 3 elements of shape [B, H, W, 2]
            self.rigid_flow_seq = {}
            self.inputs_region_deformer_net = {}
            with tf.variable_scope('objectmotion_prediction'):
                # First, warp raw images according to overall egomotion.
                for s in range(NUM_SCALES):
                    self.warped_seq[s] = []
                    self.rigid_flow_seq[s] = []
                    self.egomotions_seq[s] = []
                    for source_index in range(self.seq_length):
                        egomotion_mat_i_1 = project.get_transform_mat(
                            self.egomotion, source_index, 1, use_axis_angle=self.use_axis_angle)

                        # The gradient of the egomotion network should only come
                        # from the background, so stop gradients coming from
                        # objects.
                        if self.stop_egomotion_gradient:
                            current_seg = self.seg_stack[:, :, :, source_index * 3]  # [B, H, W]
                            background_mask = tf.equal(current_seg,
                                                       tf.constant(0, dtype=tf.uint8))  # [B, H, W]
                            background_mask = tf.tile(tf.expand_dims(background_mask, axis=3),
                                                      (1, 1, 1, 3))  # [B, H, W, 3]
                            background_mask = tf.to_float(background_mask)

                            background_mask_warped, _ = (
                                project.inverse_warp(
                                    background_mask,
                                    self.depth_upsampled[1][s],
                                    egomotion_mat_i_1,
                                    self.intrinsic_mat[:, 0, :, :],
                                    self.intrinsic_mat_inv[:, 0, :, :]))
                            # Stop gradient for mask
                            background_mask_warped = tf.stop_gradient(background_mask_warped)

                            background_warped, _ = (
                                project.inverse_warp(
                                    self.image_stack[:, :, :, source_index * 3:(source_index + 1) * 3],
                                    self.depth_upsampled[1][s],
                                    egomotion_mat_i_1,
                                    self.intrinsic_mat[:, 0, :, :],
                                    self.intrinsic_mat_inv[:, 0, :, :]))

                            obj_warped, _ = (
                                project.inverse_warp(
                                    self.image_stack[:, :, :, source_index * 3:(source_index + 1) * 3],
                                    self.depth_upsampled[1][s],
                                    tf.stop_gradient(egomotion_mat_i_1),  # stop gradient from objects
                                    self.intrinsic_mat[:, 0, :, :],
                                    self.intrinsic_mat_inv[:, 0, :, :]))

                            warped_image_i_1 = background_warped * background_mask_warped + \
                                               obj_warped * (1.0 - background_mask_warped)

                            background_rigid_flow = project.compute_rigid_flow(
                                self.depth_upsampled[1][s],
                                egomotion_mat_i_1,
                                self.intrinsic_mat[:, 0, :, :],
                                self.intrinsic_mat_inv[:, 0, :, :]
                            )  # [B, H, W, 2]

                            obj_rigid_flow = project.compute_rigid_flow(
                                self.depth_upsampled[1][s],
                                tf.stop_gradient(egomotion_mat_i_1),  # stop gradients for objects
                                self.intrinsic_mat[:, 0, :, :],
                                self.intrinsic_mat_inv[:, 0, :, :]
                            )

                            rigid_flow_i_1 = background_rigid_flow * background_mask[:, :, :, :2] + \
                                             obj_rigid_flow * (1.0 - background_mask[:, :, :, :2])
                        else:
                            warped_image_i_1, _ = (
                                project.inverse_warp(
                                    self.image_stack[:, :, :, source_index * 3:(source_index + 1) * 3],
                                    self.depth_upsampled[1][s],
                                    egomotion_mat_i_1,
                                    self.intrinsic_mat[:, 0, :, :],
                                    self.intrinsic_mat_inv[:, 0, :, :]))

                            rigid_flow_i_1 = project.compute_rigid_flow(
                                self.depth_upsampled[1][s],
                                egomotion_mat_i_1,
                                self.intrinsic_mat[:, 0, :, :],
                                self.intrinsic_mat_inv[:, 0, :, :])

                        self.warped_seq[s].append(warped_image_i_1)
                        self.rigid_flow_seq[s].append(rigid_flow_i_1)
                        self.egomotions_seq[s].append(egomotion_mat_i_1)

                    # Second, for every object in the segmentation mask, take its mask and
                    # warp it according to the egomotion estimate. Then put a threshold to
                    # binarize the warped result. Use this mask to mask out background and
                    # other objects, and pass the filtered image to the region deformer
                    # network.
                    self.object_transforms[s] = []
                    self.object_masks[s] = []
                    self.object_ids[s] = []
                    self.object_masks_warped[s] = []
                    self.inputs_region_deformer_net[s] = {}

                    for i in range(self.batch_size):
                        seg_sequence = self.seg_stack[i]  # (H, W, 9=3*3)
                        # Background is 0 and is included here.
                        object_ids = tf.unique(tf.reshape(seg_sequence, [-1]))[0]

                        self.object_ids[s].append(object_ids)
                        color_stack = []
                        mask_stack = []
                        mask_stack_warped = []
                        for j in range(self.seq_length):
                            current_image = self.warped_seq[s][j][i]  # (H, W, 3)
                            current_seg = seg_sequence[:, :, j * 3:(j + 1) * 3]  # (H, W, 3)

                            # When enforcing the object depth prior, exclude
                            # other objects when computing the neighboring mask.
                            background = tf.equal(current_seg[:, :, 0],
                                                  tf.constant(0, dtype=tf.uint8))  # [H, W]

                            def process_obj_mask_warp(obj_id):
                                """Performs warping of the individual object masks."""
                                obj_mask = tf.to_float(tf.equal(current_seg, obj_id))
                                # Warp obj_mask according to overall egomotion.
                                obj_mask_warped, _ = (
                                    project.inverse_warp(
                                        tf.expand_dims(obj_mask, axis=0),
                                        # Middle frame, highest scale, batch element i:
                                        tf.expand_dims(self.depth_upsampled[1][s][i], axis=0),
                                        # Matrix for warping j into middle frame, batch elem. i:
                                        tf.expand_dims(self.egomotions_seq[s][j][i], axis=0),
                                        tf.expand_dims(self.intrinsic_mat[i, 0, :, :], axis=0),
                                        tf.expand_dims(self.intrinsic_mat_inv[i, 0, :, :],
                                                       axis=0)))

                                obj_mask_warped = tf.squeeze(obj_mask_warped, axis=0)  # specify axis=0
                                obj_mask_binarized = tf.greater(  # Threshold to binarize mask.
                                    obj_mask_warped, tf.constant(0.5))
                                return tf.to_float(obj_mask_binarized)  # [H, W, 3]

                            def process_obj_mask(obj_id):
                                """Returns the individual object masks separately."""
                                return tf.to_float(tf.equal(current_seg, obj_id))

                            object_masks = tf.map_fn(  # (N, H, W, 3)
                                process_obj_mask, object_ids, dtype=tf.float32)

                            if self.object_depth_weight > 0:
                                # The inverse depth of a moving object should be
                                # larger than or equal to that of its horizontal
                                # surroundings.
                                depth_pred = self.depth_upsampled[j][s][i]  # [H, W, 1]

                                def get_obj_losses(obj_mask):
                                    # Note obj_mask includes background

                                    # Find width of segment
                                    coords = tf.where(tf.greater(
                                        obj_mask[:, :, 0], tf.constant(0.5, dtype=tf.float32)
                                    ))  # [num_true, 2]
                                    y_max = tf.to_int32(tf.reduce_max(coords[:, 0]))
                                    y_min = tf.to_int32(tf.reduce_min(coords[:, 0]))
                                    x_max = tf.to_int32(tf.reduce_max(coords[:, 1]))
                                    x_min = tf.to_int32(tf.reduce_min(coords[:, 1]))

                                    neighbor_pixel = 10  # empirical value

                                    id_x_min = tf.maximum(0, x_min - neighbor_pixel)
                                    id_x_max = tf.minimum(self.img_width - 1, x_max + neighbor_pixel)

                                    slice1 = tf.zeros([y_min, self.img_width])

                                    slice2_1 = tf.zeros([y_max - y_min + 1, id_x_min])
                                    slice2_2 = tf.ones([y_max - y_min + 1,
                                                        (id_x_max - id_x_min + 1)])  # neighbor
                                    slice2_3 = tf.zeros([y_max - y_min + 1,
                                                         self.img_width - 1 - id_x_max])
                                    slice2 = tf.concat([slice2_1, slice2_2, slice2_3],
                                                       axis=1)  # [y_max - y_min, W]
                                    slice3 = tf.zeros([self.img_height - 1 - y_max, self.img_width])
                                    neighbor_mask = tf.concat([slice1, slice2, slice3],
                                                              axis=0)  # [H, W]

                                    neighbor_mask = neighbor_mask * (tf.to_float(
                                        tf.less(obj_mask[:, :, 0], tf.constant(0.5, dtype=tf.float32))
                                    ))

                                    # Handle overlapping objects
                                    if self.exclude_object_mask:
                                        neighbor_mask = neighbor_mask * tf.to_float(background)  # [H, W]

                                    neighbor_depth = tf.boolean_mask(
                                        depth_pred,
                                        tf.greater(
                                            tf.reshape(neighbor_mask,
                                                       (self.img_height, self.img_width, 1)),
                                            tf.constant(0.5, dtype=tf.float32)))
                                    reference_depth = tf.boolean_mask(
                                        depth_pred, tf.greater(
                                            tf.reshape(obj_mask[:, :, 0],
                                                       (self.img_height, self.img_width, 1)),
                                            tf.constant(0.5, dtype=tf.float32)))

                                    neighbor_mean = tf.reduce_mean(neighbor_depth)
                                    reference_mean = tf.reduce_mean(reference_depth)

                                    # Soft constraint
                                    loss = tf.maximum(reference_mean - neighbor_mean - self.object_depth_threshold,
                                                      tf.constant(0.0, dtype=tf.float32))
                                    return loss

                                losses = tf.map_fn(get_obj_losses, object_masks, dtype=tf.float32)
                                # Remove background, whose id is 0
                                self.object_depth_loss += tf.reduce_mean(tf.sign(tf.to_float(object_ids)) * losses)

                            object_masks_warped = tf.map_fn(  # (N, H, W, 3)
                                process_obj_mask_warp, object_ids, dtype=tf.float32)

                            # When warping object mask, stop gradient of depth and egomotion
                            if self.stop_egomotion_gradient:
                                object_masks_warped = tf.stop_gradient(object_masks_warped)

                            filtered_images = tf.map_fn(
                                lambda mask: current_image * mask, object_masks_warped,
                                dtype=tf.float32)  # (N, H, W, 3)
                            color_stack.append(filtered_images)
                            mask_stack.append(object_masks)
                            mask_stack_warped.append(object_masks_warped)

                        # For this batch-element, if there are N moving objects,
                        # color_stack, mask_stack and mask_stack_warped each
                        # contain seq_length elements of shape (N, H, W, 3).
                        # We can now concatenate them on the last axis, creating a tensor of
                        # (N, H, W, 3*3 = 9), and, assuming N does not get too large so that
                        # we have enough memory, pass them in a single batch to the region
                        # deformer network.
                        mask_stack = tf.concat(mask_stack, axis=3)  # (N, H, W, 9)
                        mask_stack_warped = tf.concat(mask_stack_warped, axis=3)
                        color_stack = tf.concat(color_stack, axis=3)  # (N, H, W, 9)

                        if self.stop_egomotion_gradient:
                            # Gradient has been stopped before
                            image_stack = color_stack
                        else:
                            image_stack = tf.stop_gradient(color_stack)

                        all_transforms = nets.region_deformer_net(
                            image_stack=image_stack,
                            disp_bottleneck_stack=None,
                            joint_encoder=False,  # joint encoder not supported.
                            seq_length=self.seq_length,
                            weight_reg=self.weight_reg,
                            trans_params_size=self.trans_params_size,
                            region_deformer_scaling=self.region_deformer_scaling)
                        # all_transforms of shape (N, 2, 32)

                        self.object_transforms[s].append(all_transforms)
                        self.object_masks[s].append(mask_stack)
                        self.object_masks_warped[s].append(mask_stack_warped)
                        self.inputs_region_deformer_net[s][i] = color_stack
                        tf.get_variable_scope().reuse_variables()
        else:
            # Don't handle motion, classic model formulation.
            with tf.name_scope('egomotion_prediction'):
                if self.joint_encoder:
                    # Re-arrange disp_bottleneck_stack to be of shape
                    # [B, h_hid, w_hid, c_hid * seq_length]. Currently, it is a list with
                    # seq_length elements, each of dimension [B, h_hid, w_hid, c_hid].
                    disp_bottleneck_stack = tf.concat(disp_bottlenecks, axis=3)
                else:
                    disp_bottleneck_stack = None
                self.egomotion = nets.egomotion_net(
                    image_stack=self.image_stack_norm,
                    disp_bottleneck_stack=disp_bottleneck_stack,
                    joint_encoder=self.joint_encoder,
                    seq_length=self.seq_length,
                    weight_reg=self.weight_reg,
                    same_trans_rot_scaling=self.same_trans_rot_scaling)
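
The object-depth prior above compares an object's mean predicted depth with that of a horizontal band of nearby pixels; the band mask is assembled by concatenating zero and one slices around the object's bounding box. A self-contained NumPy sketch of the same slice layout with toy sizes:

import numpy as np

img_height, img_width = 12, 16
y_min, y_max = 4, 7           # object's vertical extent
x_min, x_max = 6, 9           # object's horizontal extent
neighbor_pixel = 2            # band width on each side (empirical in the source)

id_x_min = max(0, x_min - neighbor_pixel)
id_x_max = min(img_width - 1, x_max + neighbor_pixel)

slice1 = np.zeros((y_min, img_width))
slice2 = np.concatenate([
    np.zeros((y_max - y_min + 1, id_x_min)),
    np.ones((y_max - y_min + 1, id_x_max - id_x_min + 1)),
    np.zeros((y_max - y_min + 1, img_width - 1 - id_x_max)),
], axis=1)
slice3 = np.zeros((img_height - 1 - y_max, img_width))
neighbor_mask = np.concatenate([slice1, slice2, slice3], axis=0)  # (H, W)
assert neighbor_mask.shape == (img_height, img_width)
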
Example No. 12
    def build_train_graph(self):
        opt = self.opt
        loader = DataLoader(opt.dataset_dir, opt.batch_size, opt.img_height,
                            opt.img_width, opt.num_source, opt.num_scales)
        with tf.name_scope("data_loading"):
            tgt_image, src_image_stack, intrinsics = loader.load_train_batch()
            tgt_image = self.preprocess_image(tgt_image)
            src_image_stack = self.preprocess_image(src_image_stack)

        with tf.name_scope("depth_prediction"):
            pred_disp, depth_net_endpoints = disp_net(tgt_image,
                                                      is_training=True)
            pred_depth = [1. / d for d in pred_disp]

        with tf.name_scope("pose_and_explainability_prediction"):
            pred_poses, pred_exp_logits, pose_exp_net_endpoints = \
                pose_exp_net(tgt_image,
                             src_image_stack,
                             do_exp=(opt.explain_reg_weight > 0),
                             is_training=True)

        with tf.name_scope("compute_loss"):
            pixel_loss = 0
            exp_loss = 0
            smooth_loss = 0
            tgt_image_all = []
            src_image_stack_all = []
            proj_image_stack_all = []
            proj_error_stack_all = []
            exp_mask_stack_all = []
            for s in range(opt.num_scales):
                if opt.explain_reg_weight > 0:
                    # Construct a reference explainability mask (i.e. all
                    # pixels are explainable)
                    ref_exp_mask = self.get_reference_explain_mask(s)
                # Scale the source and target images for computing loss at the
                # according scale.
                curr_tgt_image = tf.image.resize_area(tgt_image, [
                    int(opt.img_height / (2**s)),
                    int(opt.img_width / (2**s))
                ])
                curr_src_image_stack = tf.image.resize_area(
                    src_image_stack, [
                        int(opt.img_height / (2**s)),
                        int(opt.img_width / (2**s))
                    ])

                if opt.smooth_weight > 0:
                    smooth_loss += opt.smooth_weight/(2**s) * \
                        self.compute_smooth_loss(pred_disp[s])

                for i in range(opt.num_source):
                    # Inverse warp the source image to the target image frame
                    curr_proj_image = projective_inverse_warp(
                        curr_src_image_stack[:, :, :, 3 * i:3 * (i + 1)],
                        tf.squeeze(pred_depth[s], axis=3), pred_poses[:, i, :],
                        intrinsics[:, s, :, :])
                    curr_proj_error = tf.abs(curr_proj_image - curr_tgt_image)
                    # Cross-entropy loss as regularization for the
                    # explainability prediction
                    if opt.explain_reg_weight > 0:
                        curr_exp_logits = tf.slice(pred_exp_logits[s],
                                                   [0, 0, 0, i * 2],
                                                   [-1, -1, -1, 2])
                        exp_loss += opt.explain_reg_weight * \
                            self.compute_exp_reg_loss(curr_exp_logits,
                                                      ref_exp_mask)
                        curr_exp = tf.nn.softmax(curr_exp_logits)
                    # Photo-consistency loss weighted by explainability
                    if opt.explain_reg_weight > 0:
                        pixel_loss += tf.reduce_mean(curr_proj_error * \
                            tf.expand_dims(curr_exp[:,:,:,1], -1))
                    else:
                        pixel_loss += tf.reduce_mean(curr_proj_error)
                    # Prepare images for tensorboard summaries
                    if i == 0:
                        proj_image_stack = curr_proj_image
                        proj_error_stack = curr_proj_error
                        if opt.explain_reg_weight > 0:
                            exp_mask_stack = tf.expand_dims(
                                curr_exp[:, :, :, 1], -1)
                    else:
                        proj_image_stack = tf.concat(
                            [proj_image_stack, curr_proj_image], axis=3)
                        proj_error_stack = tf.concat(
                            [proj_error_stack, curr_proj_error], axis=3)
                        if opt.explain_reg_weight > 0:
                        exp_mask_stack = tf.concat(
                            [exp_mask_stack,
                             tf.expand_dims(curr_exp[:, :, :, 1], -1)],
                            axis=3)
                tgt_image_all.append(curr_tgt_image)
                src_image_stack_all.append(curr_src_image_stack)
                proj_image_stack_all.append(proj_image_stack)
                proj_error_stack_all.append(proj_error_stack)
                if opt.explain_reg_weight > 0:
                    exp_mask_stack_all.append(exp_mask_stack)
            total_loss = pixel_loss + smooth_loss + exp_loss

        with tf.name_scope("train_op"):
            train_vars = [var for var in tf.trainable_variables()]
            optim = tf.train.AdamOptimizer(opt.learning_rate, opt.beta1)
            # self.grads_and_vars = optim.compute_gradients(total_loss,
            #                                               var_list=train_vars)
            # self.train_op = optim.apply_gradients(self.grads_and_vars)
            self.train_op = slim.learning.create_train_op(total_loss, optim)
            self.global_step = tf.Variable(0,
                                           name='global_step',
                                           trainable=False)
            self.incr_global_step = tf.assign(self.global_step,
                                              self.global_step + 1)

        # Collect tensors that are useful later (e.g. tf summary)
        self.pred_depth = pred_depth
        self.pred_poses = pred_poses
        self.steps_per_epoch = loader.steps_per_epoch
        self.total_loss = total_loss
        self.pixel_loss = pixel_loss
        self.exp_loss = exp_loss
        self.smooth_loss = smooth_loss
        self.tgt_image_all = tgt_image_all
        self.src_image_stack_all = src_image_stack_all
        self.proj_image_stack_all = proj_image_stack_all
        self.proj_error_stack_all = proj_error_stack_all
        self.exp_mask_stack_all = exp_mask_stack_all
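
The loss loop above builds an image pyramid with `tf.image.resize_area` and halves the smoothness weight at each coarser scale. A short TF 1.x sketch of just that per-scale setup (toy image, hypothetical weight):

import tensorflow as tf  # TensorFlow 1.x API assumed

img_height, img_width, num_scales = 128, 416, 4
smooth_weight = 0.5  # hypothetical value
tgt_image = tf.ones([4, img_height, img_width, 3])

scaled_images = []
scale_weights = []
for s in range(num_scales):
    scaled_images.append(tf.image.resize_area(
        tgt_image, [int(img_height / (2 ** s)), int(img_width / (2 ** s))]))
    # Coarser scales contribute a proportionally smaller smoothness penalty.
    scale_weights.append(smooth_weight / (2 ** s))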