Example #1
    def build_test(self):
        self.loader = DataLoader(trainable=False, **self.config)
        self.num_scales = self.loader.num_scales
        self.num_source = self.loader.num_source
        with tf.name_scope('data_loading'):
            self.tgt_image = tf.placeholder(tf.uint8, [
                self.loader.batch_size, self.loader.img_height,
                self.loader.img_width, 3
            ])

            self.src_image_stack = tf.placeholder(tf.uint8, [
                self.loader.batch_size, self.loader.img_height,
                self.loader.img_width, 3 * self.num_source
            ])

            tgt_image = tf.image.convert_image_dtype(self.tgt_image,
                                                     dtype=tf.float32)
            src_image_stack = tf.image.convert_image_dtype(
                self.src_image_stack, dtype=tf.float32)
            if self.preprocess:
                tgt_image_net = self.preprocess_image(tgt_image)
                src_image_stack_net = self.preprocess_image(src_image_stack)
            else:
                tgt_image_net = tgt_image
                src_image_stack_net = src_image_stack

        with tf.variable_scope('monodepth2_model',
                               reuse=tf.AUTO_REUSE) as scope:
            net_builder = Net(False, **self.config)
            num_source = src_image_stack_net.get_shape().as_list()[-1] // 3
            assert num_source == 2
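            # Shared ResNet-18 encoder: features are extracted from the target
            # frame and from each of the two source frames; only the target's
            # skip connections are kept, since they feed the disparity decoder.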
            res18_tc, skips_tc = net_builder.build_resnet18(tgt_image_net)
            res18_tp, _ = net_builder.build_resnet18(
                src_image_stack_net[:, :, :, :3])
            res18_tn, _ = net_builder.build_resnet18(
                src_image_stack_net[:, :, :, 3:])

            pred_poses = net_builder.build_pose_net(res18_tp, res18_tc,
                                                    res18_tn)

            pred_disp = net_builder.build_disp_net(res18_tc, skips_tc)

            H = tgt_image.get_shape().as_list()[1]
            W = tgt_image.get_shape().as_list()[2]

            pred_disp_rawscale = [
                tf.image.resize_bilinear(
                    pred_disp[i],
                    [self.loader.img_height, self.loader.img_width])
                for i in range(self.num_scales)
            ]

            pred_depth_rawscale = disp_to_depth(pred_disp_rawscale,
                                                self.min_depth, self.max_depth)

            #pred_depth_rawscale = colorize(pred_depth_rawscale, cmap='magma')

        # Collect tensors that are useful later (e.g. tf summary)

        self.pred_depth = pred_depth_rawscale[0]  # full-resolution depth (scale 0)
        self.pred_disp = pred_disp
        self.pred_poses = pred_poses
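
A minimal usage sketch for the test graph above, assuming a hypothetical wrapper class `MonoDepth2Learner` that owns `build_test()` and a `config` dict with the loader and network options (both names are illustrative, not taken from the original code):

import numpy as np
import tensorflow as tf

learner = MonoDepth2Learner(**config)  # hypothetical wrapper exposing build_test()
learner.build_test()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Dummy uint8 frames shaped like the placeholders defined in build_test.
    tgt = np.zeros([learner.loader.batch_size, learner.loader.img_height,
                    learner.loader.img_width, 3], dtype=np.uint8)
    src = np.zeros([learner.loader.batch_size, learner.loader.img_height,
                    learner.loader.img_width, 3 * learner.num_source],
                   dtype=np.uint8)
    depth, poses = sess.run([learner.pred_depth, learner.pred_poses],
                            feed_dict={learner.tgt_image: tgt,
                                       learner.src_image_stack: src})

A real evaluation run would restore trained weights with a tf.train.Saver instead of initializing them randomly.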
Example #2
    def build_train(self):
        loader = DataLoader(trainable=True, **self.config)
        self.num_scales = loader.num_scales
        self.num_source = loader.num_source
        with tf.name_scope('data_loading'):
            (tgt_image, src_image_stack, tgt_image_aug, src_image_stack_aug,
             intrinsics) = loader.load_batch()
            tgt_image = tf.image.convert_image_dtype(tgt_image,
                                                     dtype=tf.float32)
            src_image_stack = tf.image.convert_image_dtype(src_image_stack,
                                                           dtype=tf.float32)
            tgt_image_aug = tf.image.convert_image_dtype(tgt_image_aug,
                                                         dtype=tf.float32)
            src_image_stack_aug = tf.image.convert_image_dtype(
                src_image_stack_aug, dtype=tf.float32)
            if self.preprocess:
                tgt_image_net = self.preprocess_image(tgt_image_aug)
                src_image_stack_net = self.preprocess_image(
                    src_image_stack_aug)
            else:
                tgt_image_net = tgt_image_aug
                src_image_stack_net = src_image_stack_aug

        with tf.variable_scope('monodepth2_model',
                               reuse=tf.AUTO_REUSE) as scope:
            net_builder = Net(True, **self.config)
            num_source = src_image_stack_net.get_shape().as_list()[-1] // 3
            assert num_source == 2
            res18_tc, skips_tc = net_builder.build_resnet18(tgt_image_net)
            res18_tp, _ = net_builder.build_resnet18(
                src_image_stack_net[:, :, :, :3])
            res18_tn, _ = net_builder.build_resnet18(
                src_image_stack_net[:, :, :, 3:])

            pred_poses = net_builder.build_pose_net(res18_tp, res18_tc,
                                                    res18_tn)

            pred_disp = net_builder.build_disp_net(res18_tc, skips_tc)

            H = tgt_image.get_shape().as_list()[1]
            W = tgt_image.get_shape().as_list()[2]

            pred_disp_rawscale = [
                tf.image.resize_bilinear(pred_disp[i],
                                         [loader.img_height, loader.img_width])
                for i in range(self.num_scales)
            ]
            pred_depth_rawscale = disp_to_depth(pred_disp_rawscale,
                                                self.min_depth, self.max_depth)

            tgt_image_pyramid = [
                tf.image.resize_nearest_neighbor(
                    tgt_image, [H // (2**s), W // (2**s)])
                for s in range(self.num_scales)
            ]

        with tf.name_scope('compute_loss'):
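            # Per-scale loss: the pixel-wise minimum reprojection error over the
            # two source frames (optionally combined with identity reprojection
            # terms for auto-masking) plus an edge-aware smoothness term weighted
            # by self.smoothness_ratio; all terms are averaged over the scales.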
            tgt_image_stack_all = []
            src_image_stack_all = []
            proj_image_stack_all = []
            proj_error_stack_all = []
            pixel_losses = 0.
            smooth_losses = 0.
            total_loss = 0.
            if self.auto_mask:
                pred_auto_masks1 = []
                pred_auto_masks2 = []
            for s in range(loader.num_scales):
                reprojection_losses = []
                for i in range(num_source):
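                    # Warp source frame i into the target view using the
                    # predicted depth, the predicted relative pose and the
                    # camera intrinsics.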
                    curr_proj_image = projective_inverse_warp(
                        src_image_stack[:, :, :, 3 * i:3 * (i + 1)],
                        tf.squeeze(pred_depth_rawscale[s], axis=3),
                        pred_poses[:, i, :],
                        intrinsics=intrinsics[:, 0, :, :])
                    curr_proj_error = tf.abs(curr_proj_image - tgt_image)

                    # if self.auto_mask and s == 0:
                    #     proj_image_scale0.append(curr_proj_image)
                    #     proj_error_scale0.append(curr_proj_error)

                    reprojection_losses.append(
                        self.compute_reprojection_loss(curr_proj_image,
                                                       tgt_image))

                    if i == 0:
                        proj_image_stack = curr_proj_image
                        proj_error_stack = curr_proj_error
                    else:
                        proj_image_stack = tf.concat(
                            [proj_image_stack, curr_proj_image], axis=3)
                        proj_error_stack = tf.concat(
                            [proj_error_stack, curr_proj_error], axis=3)

                reprojection_losses = tf.concat(reprojection_losses, axis=3)

                combined = reprojection_losses
                if self.auto_mask:
                    identity_reprojection_losses = []
                    for i in range(num_source):
                        identity_reprojection_losses.append(
                            self.compute_reprojection_loss(
                                src_image_stack[:, :, :, 3 * i:3 * (i + 1)],
                                tgt_image))
                    identity_reprojection_losses = tf.concat(
                        identity_reprojection_losses, axis=3)

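                    # Tiny random noise breaks ties between identity and warped
                    # reprojection losses (the monodepth2 auto-masking trick).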
                    identity_reprojection_losses += (tf.random_normal(
                        identity_reprojection_losses.get_shape()) * 1e-5)

                    combined = tf.concat(
                        [identity_reprojection_losses, reprojection_losses],
                        axis=3)

                    pred_auto_masks1.append(
                        tf.expand_dims(
                            tf.cast(
                                tf.argmin(
                                    tf.concat([
                                        combined[:, :, :, 1:2],
                                        combined[:, :, :, 3:4]
                                    ], axis=3),
                                    axis=3), tf.float32) * 255, -1))
                    pred_auto_masks2.append(
                        tf.expand_dims(
                            tf.cast(tf.argmin(combined, axis=3) > 1,
                                    tf.float32) * 255, -1))

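                # The per-pixel minimum over the concatenated losses implements
                # both minimum reprojection and auto-masking: where the identity
                # loss is smallest, that value is used instead of a warped
                # reprojection error.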
                reprojection_loss = tf.reduce_mean(
                    tf.reduce_min(combined, axis=3))

                pixel_losses += reprojection_loss

                smooth_loss = self.get_smooth_loss(
                    pred_disp[s], tgt_image_pyramid[s]) / (2**s)
                smooth_losses += smooth_loss
                scale_total_loss = reprojection_loss + self.smoothness_ratio * smooth_loss
                total_loss += scale_total_loss

                tgt_image_stack_all.append(tgt_image)
                src_image_stack_all.append(src_image_stack_aug)
                proj_image_stack_all.append(proj_image_stack)
                proj_error_stack_all.append(proj_error_stack)

            total_loss /= loader.num_scales
            pixel_losses /= loader.num_scales
            smooth_losses /= loader.num_scales

        with tf.name_scope('train_op'):
            self.total_step = self.total_epoch * loader.steps_per_epoch
            self.global_step = tf.Variable(0,
                                           name='global_step',
                                           trainable=False)
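            # Piecewise-constant schedule: the initial learning rate is divided
            # by 10 after 3/4 of the total training steps.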
            learning_rates = [
                self.start_learning_rate, self.start_learning_rate / 10
            ]
            boundaries = [int(self.total_step * 3 / 4)]
            learning_rate = tf.train.piecewise_constant(
                self.global_step, boundaries, learning_rates)
            optim = tf.train.AdamOptimizer(learning_rate, self.beta1)
            self.train_op = slim.learning.create_train_op(total_loss, optim)
            self.incr_global_step = tf.assign(self.global_step,
                                              self.global_step + 1)

        # Collect tensors that are useful later (e.g. tf summary)
        self.pred_depth = pred_depth_rawscale
        self.pred_disp = pred_disp
        self.pred_poses = pred_poses
        self.steps_per_epoch = loader.steps_per_epoch
        self.total_loss = total_loss
        self.pixel_loss = pixel_losses
        self.smooth_loss = smooth_losses
        self.tgt_image_all = tgt_image_stack_all
        self.src_image_stack_all = src_image_stack_all
        self.proj_image_stack_all = proj_image_stack_all
        self.proj_error_stack_all = proj_error_stack_all
        if self.auto_mask:
            self.pred_auto_masks1 = pred_auto_masks1
            self.pred_auto_masks2 = pred_auto_masks2
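
Both examples call a `disp_to_depth` helper that is not shown here. The sketch below follows the usual monodepth2 convention and is an assumption about that helper, adapted to the way it is used above (a list of per-scale sigmoid disparities mapped to depths bounded by `min_depth` and `max_depth`):

def disp_to_depth(disp_list, min_depth, max_depth):
    """Sketch of a monodepth2-style disparity-to-depth conversion (assumed).

    Each entry of `disp_list` is assumed to be a sigmoid disparity in [0, 1];
    it is rescaled to [1/max_depth, 1/min_depth] and inverted to a depth map.
    """
    min_disp = 1.0 / max_depth
    max_disp = 1.0 / min_depth
    depths = []
    for disp in disp_list:
        scaled_disp = min_disp + (max_disp - min_disp) * disp
        depths.append(1.0 / scaled_disp)
    return depths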