Example No. 1
    def build_test(self, build_type='both'):
        self.loader = DataLoader(trainable=False, **self.config)
        self.num_scales = self.loader.num_scales
        self.num_source = self.loader.num_source
        with tf.name_scope('data_loading'):
            self.tgt_image = tf.placeholder(tf.uint8, [self.loader.batch_size,
                    self.loader.img_height, self.loader.img_width, 3])
            tgt_image = tf.image.convert_image_dtype(self.tgt_image, dtype=tf.float32)
            # Apply the same optional normalization as build_train does.
            tgt_image_net = self.preprocess_image(tgt_image) if self.preprocess else tgt_image
            if build_type != 'depth':
                self.src_image_stack = tf.placeholder(tf.uint8, [self.loader.batch_size,
                    self.loader.img_height, self.loader.img_width, 3 * self.num_source])
                src_image_stack = tf.image.convert_image_dtype(self.src_image_stack, dtype=tf.float32)
                src_image_stack_net = self.preprocess_image(src_image_stack) if self.preprocess else src_image_stack

        with tf.variable_scope('monodepth2_model', reuse=tf.AUTO_REUSE) as scope:
            net_builder = Net(False, **self.config)

            res18_tc, skips_tc = net_builder.build_resnet18(tgt_image_net)
            pred_disp = net_builder.build_disp_net(res18_tc, skips_tc)
            pred_disp_rawscale = [tf.image.resize_bilinear(pred_disp[i], [self.loader.img_height, self.loader.img_width]) for i in
                range(self.num_scales)]
            pred_depth_rawscale = disp_to_depth(pred_disp_rawscale, self.min_depth, self.max_depth)

            self.pred_depth = pred_depth_rawscale[0]
            self.pred_disp = pred_disp_rawscale[0]

            if build_type != 'depth':
                num_source = int(src_image_stack_net.get_shape().as_list()[-1] // 3)
                assert num_source == 2

                if self.pose_type == 'seperate':
                    res18_ctp, _ = net_builder.build_resnet18(
                        # Keep the same input order as build_train: [previous source frame, target frame].
                        tf.concat([src_image_stack_net[:, :, :, :3], tgt_image_net], axis=3),
                        prefix='pose_'
                    )
                    res18_ctn, _ = net_builder.build_resnet18(
                        tf.concat([tgt_image_net, src_image_stack_net[:, :, :, 3:]], axis=3),
                        prefix='pose_'
                    )
                elif self.pose_type == 'shared':
                    res18_tp, _ = net_builder.build_resnet18(src_image_stack_net[:, :, :, :3])
                    res18_tn, _ = net_builder.build_resnet18(src_image_stack_net[:, :, :, 3:])
                    # Match the feature concat order used in build_train.
                    res18_ctp = tf.concat([res18_tp, res18_tc], axis=3)
                    res18_ctn = tf.concat([res18_tc, res18_tn], axis=3)
                else:
                    raise NotImplementedError

                pred_pose_ctp = net_builder.build_pose_net2(res18_ctp)
                pred_pose_ctn = net_builder.build_pose_net2(res18_ctn)

                pred_poses = tf.concat([pred_pose_ctp, pred_pose_ctn], axis=1)

                self.pred_poses = pred_poses

    def build_train(self):
        self.ssim_ratio = float(self.config['model']['reproj_alpha'])
        self.smoothness_ratio = float(self.config['model']['smooth_alpha'])
        self.start_learning_rate = float(self.config['model']['learning_rate'])
        self.total_epoch = int(self.config['model']['epoch'])
        self.beta1 = float(self.config['model']['beta1'])
        self.continue_ckpt = self.config['model']['continue_ckpt']
        self.torch_res18_ckpt = self.config['model']['torch_res18_ckpt']
        self.summary_freq = self.config['model']['summary_freq']
        self.auto_mask = self.config['model']['auto_mask']

        loader = DataLoader(trainable=True, **self.config)
        self.num_scales = loader.num_scales
        self.num_source = loader.num_source
        with tf.name_scope('data_loading'):
            tgt_image, src_image_stack, tgt_image_aug, src_image_stack_aug, intrinsics = loader.load_batch()
            tgt_image = tf.image.convert_image_dtype(tgt_image, dtype=tf.float32)
            src_image_stack = tf.image.convert_image_dtype(src_image_stack, dtype=tf.float32)
            tgt_image_aug = tf.image.convert_image_dtype(tgt_image_aug, dtype=tf.float32)
            src_image_stack_aug = tf.image.convert_image_dtype(src_image_stack_aug, dtype=tf.float32)
            if self.preprocess:
                tgt_image_net = self.preprocess_image(tgt_image_aug)
                src_image_stack_net = self.preprocess_image(src_image_stack_aug)
            else:
                tgt_image_net = tgt_image_aug
                src_image_stack_net = src_image_stack_aug

        with tf.variable_scope('monodepth2_model', reuse=tf.AUTO_REUSE) as scope:
            net_builder = Net(True, **self.config)
            num_source = int(src_image_stack_net.get_shape().as_list()[-1] // 3)
            assert num_source == 2
            res18_tc, skips_tc = net_builder.build_resnet18(tgt_image_net)

            if self.pose_type == 'seperate':
                res18_ctp, _ = net_builder.build_resnet18(
                    tf.concat([src_image_stack_net[:, :, :, :3], tgt_image_net], axis=3),
                    prefix='pose_'
                )
                res18_ctn, _ = net_builder.build_resnet18(
                    tf.concat([tgt_image_net, src_image_stack_net[:, :, :, 3:]], axis=3),
                    prefix='pose_'
                )
            elif self.pose_type == 'shared':
                res18_tp, _ = net_builder.build_resnet18(src_image_stack_net[:, :, :, :3])
                res18_tn, _ = net_builder.build_resnet18(src_image_stack_net[:, :, :, 3:])
                res18_ctp = tf.concat([res18_tp, res18_tc], axis=3)
                res18_ctn = tf.concat([res18_tc, res18_tn], axis=3)
            else:
                raise NotImplementedError

            pred_pose_ctp = net_builder.build_pose_net2(res18_ctp)
            pred_pose_ctn = net_builder.build_pose_net2(res18_ctn)

            pred_poses = tf.concat([pred_pose_ctp, pred_pose_ctn], axis=1)

            # res18_tp, _ = net_builder.build_resnet18(src_image_stack_net[:,:,:,:3])
            # res18_tn, _= net_builder.build_resnet18(src_image_stack_net[:,:,:,3:])
            #
            # pred_poses = net_builder.build_pose_net(res18_tp, res18_tc, res18_tn)

            pred_disp = net_builder.build_disp_net(res18_tc,skips_tc)

            H = tgt_image.get_shape().as_list()[1]
            W = tgt_image.get_shape().as_list()[2]


            pred_disp_rawscale = [tf.image.resize_bilinear(pred_disp[i], [loader.img_height, loader.img_width]) for i in range(self.num_scales)]
            pred_depth_rawscale = disp_to_depth(pred_disp_rawscale, self.min_depth, self.max_depth)

            tgt_image_pyramid = [tf.image.resize_nearest_neighbor(tgt_image, [int(H // (2 ** s)), int(W // (2 ** s))]) for s in range(self.num_scales)]

        with tf.name_scope('compute_loss'):
            tgt_image_stack_all = []
            src_image_stack_all = []
            proj_image_stack_all = []
            proj_error_stack_all = []
            pixel_losses = 0.
            smooth_losses = 0.
            total_loss = 0.
            if self.auto_mask:
                pred_auto_masks = []
            for s in range(loader.num_scales):
                reprojection_losses = []
                for i in range(num_source):
                    curr_proj_image = projective_inverse_warp(src_image_stack[:,:,:, 3*i:3*(i+1)],
                                                              tf.squeeze(pred_depth_rawscale[s], axis=3),
                                                              pred_poses[:,i,:],
                                                              intrinsics=intrinsics[:, 0, :, :], invert=(i == 0))
                    curr_proj_error = tf.abs(curr_proj_image - tgt_image)

                    reprojection_losses.append(self.compute_reprojection_loss(curr_proj_image, tgt_image))

                    if i == 0:
                        proj_image_stack = curr_proj_image
                        proj_error_stack = curr_proj_error
                    else:
                        proj_image_stack = tf.concat([proj_image_stack,curr_proj_image], axis=3)
                        proj_error_stack = tf.concat([proj_error_stack,curr_proj_error], axis=3)

                reprojection_losses = tf.concat(reprojection_losses, axis=3)

                combined = reprojection_losses
                if self.auto_mask:
                    identity_reprojection_losses = []
                    for i in range(num_source):
                        identity_reprojection_losses.append(self.compute_reprojection_loss(src_image_stack[:,:,:, 3*i:3*(i+1)], tgt_image))
                    identity_reprojection_losses = tf.concat(identity_reprojection_losses, axis=3)

                    identity_reprojection_losses += (tf.random_normal(identity_reprojection_losses.get_shape()) * 1e-5)

                    combined = tf.concat([identity_reprojection_losses, reprojection_losses], axis=3)
                    pred_auto_masks.append(tf.expand_dims(tf.cast(tf.argmin(combined, axis=3) > 1, tf.float32) * 255, -1))

                reprojection_loss = tf.reduce_mean(tf.reduce_min(combined, axis=3))

                pixel_losses += reprojection_loss

                # Scale the smoothness term by 1 / 2**s before accumulating, as in Example No. 4.
                smooth_loss = self.get_smooth_loss(pred_disp[s], tgt_image_pyramid[s]) / (2 ** s)
                smooth_losses += smooth_loss

                scale_total_loss = reprojection_loss + self.smoothness_ratio * smooth_loss
                total_loss += scale_total_loss

                tgt_image_stack_all.append(tgt_image)
                src_image_stack_all.append(src_image_stack_aug)
                proj_image_stack_all.append(proj_image_stack)
                proj_error_stack_all.append(proj_error_stack)

            total_loss /= loader.num_scales
            pixel_losses /= loader.num_scales
            smooth_losses /= loader.num_scales

        with tf.name_scope('train_op'):
            self.total_step = self.total_epoch * loader.steps_per_epoch
            self.global_step = tf.Variable(0, name='global_step', trainable=False)
            learning_rates = [self.start_learning_rate, self.start_learning_rate / 10]
            boundaries = [int(self.total_step * 3 / 4)]
            self.learning_rate = tf.train.piecewise_constant(self.global_step, boundaries, learning_rates)
            optimizer = tf.train.AdamOptimizer(self.learning_rate, self.beta1)

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.train_op = optimizer.minimize(total_loss, global_step=self.global_step)

            self.incr_global_step = tf.assign(self.global_step, self.global_step + 1)

        # Collect tensors that are useful later (e.g. tf summary)
        self.pred_depth = pred_depth_rawscale
        self.pred_disp = pred_disp
        self.pred_poses = pred_poses
        self.steps_per_epoch = loader.steps_per_epoch
        self.total_loss = total_loss
        self.pixel_loss = pixel_losses
        self.smooth_loss = smooth_losses
        self.tgt_image_all = tgt_image_stack_all
        self.src_image_stack_all = src_image_stack_all
        self.proj_image_stack_all = proj_image_stack_all
        self.proj_error_stack_all = proj_error_stack_all
        if self.auto_mask:
            self.pred_auto_masks = pred_auto_masks
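
A note on the helpers: the loss above calls self.compute_reprojection_loss, which none of these examples define. The sketch below is a minimal, hypothetical version assuming the monodepth2-style formulation (an SSIM/L1 blend weighted by ssim_ratio, returned as a per-pixel map so tf.reduce_min can pick the best source view per pixel); the _ssim helper, its 3x3 average-pooling windows, and the default weight of 0.85 are assumptions, not the original implementation.

import tensorflow as tf

def _ssim(x, y):
    # Hypothetical per-pixel SSIM dissimilarity map built from 3x3 average pooling.
    c1, c2 = 0.01 ** 2, 0.03 ** 2
    mu_x = tf.nn.avg_pool(x, [1, 3, 3, 1], [1, 1, 1, 1], 'SAME')
    mu_y = tf.nn.avg_pool(y, [1, 3, 3, 1], [1, 1, 1, 1], 'SAME')
    sigma_x = tf.nn.avg_pool(x ** 2, [1, 3, 3, 1], [1, 1, 1, 1], 'SAME') - mu_x ** 2
    sigma_y = tf.nn.avg_pool(y ** 2, [1, 3, 3, 1], [1, 1, 1, 1], 'SAME') - mu_y ** 2
    sigma_xy = tf.nn.avg_pool(x * y, [1, 3, 3, 1], [1, 1, 1, 1], 'SAME') - mu_x * mu_y
    ssim_n = (2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2)
    ssim_d = (mu_x ** 2 + mu_y ** 2 + c1) * (sigma_x + sigma_y + c2)
    # Map SSIM similarity onto a dissimilarity in [0, 1].
    return tf.clip_by_value((1.0 - ssim_n / ssim_d) / 2.0, 0.0, 1.0)

def compute_reprojection_loss(proj_image, tgt_image, ssim_ratio=0.85):
    # Per-pixel photometric error: weighted SSIM + L1, kept as a [B, H, W, 1] map.
    l1 = tf.reduce_mean(tf.abs(proj_image - tgt_image), axis=3, keepdims=True)
    ssim = tf.reduce_mean(_ssim(proj_image, tgt_image), axis=3, keepdims=True)
    return ssim_ratio * ssim + (1.0 - ssim_ratio) * l1
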
Example No. 3
    def build_test(self):
        self.loader = DataLoader(trainable=False, **self.config)
        self.num_scales = self.loader.num_scales
        self.num_source = self.loader.num_source
        with tf.name_scope('data_loading'):
            self.tgt_image = tf.placeholder(tf.uint8, [
                self.loader.batch_size, self.loader.img_height,
                self.loader.img_width, 3
            ])

            self.src_image_stack = tf.placeholder(tf.uint8, [
                self.loader.batch_size, self.loader.img_height,
                self.loader.img_width, 3 * self.num_source
            ])

            tgt_image = tf.image.convert_image_dtype(self.tgt_image,
                                                     dtype=tf.float32)
            src_image_stack = tf.image.convert_image_dtype(
                self.src_image_stack, dtype=tf.float32)
            if self.preprocess:
                tgt_image_net = self.preprocess_image(tgt_image)
                src_image_stack_net = self.preprocess_image(src_image_stack)
            else:
                tgt_image_net = tgt_image
                src_image_stack_net = src_image_stack

        with tf.variable_scope('monodepth2_model',
                               reuse=tf.AUTO_REUSE) as scope:
            net_builder = Net(False, **self.config)
            num_source = int(
                src_image_stack_net.get_shape().as_list()[-1] // 3)
            assert num_source == 2
            res18_tc, skips_tc = net_builder.build_resnet18(tgt_image_net)
            res18_tp, _ = net_builder.build_resnet18(
                src_image_stack_net[:, :, :, :3])
            res18_tn, _ = net_builder.build_resnet18(
                src_image_stack_net[:, :, :, 3:])

            pred_poses = net_builder.build_pose_net(res18_tp, res18_tc,
                                                    res18_tn)

            pred_disp = net_builder.build_disp_net(res18_tc, skips_tc)

            H = tgt_image.get_shape().as_list()[1]
            W = tgt_image.get_shape().as_list()[2]

            pred_disp_rawscale = [
                tf.image.resize_bilinear(
                    pred_disp[i],
                    [self.loader.img_height, self.loader.img_width])
                for i in range(self.num_scales)
            ]

            pred_depth_rawscale = disp_to_depth(pred_disp_rawscale,
                                                self.min_depth, self.max_depth)

            #pred_depth_rawscale = colorize(pred_depth_rawscale, cmap='magma')

        # Collect tensors that are useful later (e.g. tf summary)

        self.pred_depth = pred_depth_rawscale[0]
        self.pred_disp = pred_disp
        self.pred_poses = pred_poses
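
disp_to_depth is called on the per-scale disparity list in every example but is not shown either. Below is a minimal sketch, assuming the monodepth2 convention of mapping a sigmoid disparity in [0, 1] onto the [min_depth, max_depth] range; the list handling mirrors how the examples pass pred_disp_rawscale in as a list.

import tensorflow as tf

def disp_to_depth(disp, min_depth, max_depth):
    # Convert network disparity (sigmoid output in [0, 1]) to depth in [min_depth, max_depth].
    def _convert(d):
        min_disp = 1.0 / max_depth
        max_disp = 1.0 / min_depth
        scaled_disp = min_disp + (max_disp - min_disp) * d
        return 1.0 / scaled_disp

    # The callers above pass a list of per-scale tensors, so handle both cases.
    if isinstance(disp, (list, tuple)):
        return [_convert(d) for d in disp]
    return _convert(disp)
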
Example No. 4
    def build_train(self):
        loader = DataLoader(trainable=True, **self.config)
        self.num_scales = loader.num_scales
        self.num_source = loader.num_source
        with tf.name_scope('data_loading'):
            tgt_image, src_image_stack, tgt_image_aug, src_image_stack_aug, intrinsics = loader.load_batch()
            tgt_image = tf.image.convert_image_dtype(tgt_image,
                                                     dtype=tf.float32)
            src_image_stack = tf.image.convert_image_dtype(src_image_stack,
                                                           dtype=tf.float32)
            tgt_image_aug = tf.image.convert_image_dtype(tgt_image_aug,
                                                         dtype=tf.float32)
            src_image_stack_aug = tf.image.convert_image_dtype(
                src_image_stack_aug, dtype=tf.float32)
            if self.preprocess:
                tgt_image_net = self.preprocess_image(tgt_image_aug)
                src_image_stack_net = self.preprocess_image(
                    src_image_stack_aug)
            else:
                tgt_image_net = tgt_image_aug
                src_image_stack_net = src_image_stack_aug

        with tf.variable_scope('monodepth2_model',
                               reuse=tf.AUTO_REUSE) as scope:
            net_builder = Net(True, **self.config)
            num_source = int(
                src_image_stack_net.get_shape().as_list()[-1] // 3)
            assert num_source == 2
            res18_tc, skips_tc = net_builder.build_resnet18(tgt_image_net)
            res18_tp, _ = net_builder.build_resnet18(
                src_image_stack_net[:, :, :, :3])
            res18_tn, _ = net_builder.build_resnet18(
                src_image_stack_net[:, :, :, 3:])

            pred_poses = net_builder.build_pose_net(res18_tp, res18_tc,
                                                    res18_tn)

            pred_disp = net_builder.build_disp_net(res18_tc, skips_tc)

            H = tgt_image.get_shape().as_list()[1]
            W = tgt_image.get_shape().as_list()[2]

            pred_disp_rawscale = [
                tf.image.resize_bilinear(pred_disp[i],
                                         [loader.img_height, loader.img_width])
                for i in range(self.num_scales)
            ]
            pred_depth_rawscale = disp_to_depth(pred_disp_rawscale,
                                                self.min_depth, self.max_depth)

            tgt_image_pyramid = [
                tf.image.resize_nearest_neighbor(
                    tgt_image, [int(H // (2 ** s)), int(W // (2 ** s))])
                for s in range(self.num_scales)
            ]

        with tf.name_scope('compute_loss'):
            tgt_image_stack_all = []
            src_image_stack_all = []
            proj_image_stack_all = []
            proj_error_stack_all = []
            pixel_losses = 0.
            smooth_losses = 0.
            total_loss = 0.
            if self.auto_mask:
                pred_auto_masks1 = []
                pred_auto_masks2 = []
            for s in range(loader.num_scales):
                reprojection_losses = []
                for i in range(num_source):
                    curr_proj_image = projective_inverse_warp(
                        src_image_stack[:, :, :, 3 * i:3 * (i + 1)],
                        tf.squeeze(pred_depth_rawscale[s], axis=3),
                        pred_poses[:, i, :],
                        intrinsics=intrinsics[:, 0, :, :])
                    curr_proj_error = tf.abs(curr_proj_image - tgt_image)

                    reprojection_losses.append(
                        self.compute_reprojection_loss(curr_proj_image,
                                                       tgt_image))

                    if i == 0:
                        proj_image_stack = curr_proj_image
                        proj_error_stack = curr_proj_error
                    else:
                        proj_image_stack = tf.concat(
                            [proj_image_stack, curr_proj_image], axis=3)
                        proj_error_stack = tf.concat(
                            [proj_error_stack, curr_proj_error], axis=3)

                reprojection_losses = tf.concat(reprojection_losses, axis=3)

                combined = reprojection_losses
                if self.auto_mask:
                    identity_reprojection_losses = []
                    for i in range(num_source):
                        identity_reprojection_losses.append(
                            self.compute_reprojection_loss(
                                src_image_stack[:, :, :, 3 * i:3 * (i + 1)],
                                tgt_image))
                    identity_reprojection_losses = tf.concat(
                        identity_reprojection_losses, axis=3)

                    identity_reprojection_losses += (tf.random_normal(
                        identity_reprojection_losses.get_shape()) * 1e-5)

                    combined = tf.concat(
                        [identity_reprojection_losses, reprojection_losses],
                        axis=3)

                    pred_auto_masks1.append(
                        tf.expand_dims(
                            tf.cast(
                                tf.argmin(
                                    tf.concat([combined[:, :, :, 1:2],
                                               combined[:, :, :, 3:4]],
                                              axis=3),
                                    axis=3),
                                tf.float32) * 255, -1))
                    pred_auto_masks2.append(
                        tf.expand_dims(
                            tf.cast(tf.argmin(combined, axis=3) > 1,
                                    tf.float32) * 255, -1))

                reprojection_loss = tf.reduce_mean(
                    tf.reduce_min(combined, axis=3))

                pixel_losses += reprojection_loss

                smooth_loss = self.get_smooth_loss(
                    pred_disp[s], tgt_image_pyramid[s]) / (2**s)
                smooth_losses += smooth_loss
                scale_total_loss = reprojection_loss + self.smoothness_ratio * smooth_loss
                total_loss += scale_total_loss

                tgt_image_stack_all.append(tgt_image)
                src_image_stack_all.append(src_image_stack_aug)
                proj_image_stack_all.append(proj_image_stack)
                proj_error_stack_all.append(proj_error_stack)

            total_loss /= loader.num_scales
            pixel_losses /= loader.num_scales
            smooth_losses /= loader.num_scales

        with tf.name_scope('train_op'):
            self.total_step = self.total_epoch * loader.steps_per_epoch
            self.global_step = tf.Variable(0,
                                           name='global_step',
                                           trainable=False)
            learning_rates = [
                self.start_learning_rate, self.start_learning_rate / 10
            ]
            boundaries = [int(self.total_step * 3 / 4)]
            learning_rate = tf.train.piecewise_constant(
                self.global_step, boundaries, learning_rates)
            optim = tf.train.AdamOptimizer(learning_rate, self.beta1)
            self.train_op = slim.learning.create_train_op(total_loss, optim)
            self.incr_global_step = tf.assign(self.global_step,
                                              self.global_step + 1)

        # Collect tensors that are useful later (e.g. tf summary)
        self.pred_depth = pred_depth_rawscale
        self.pred_disp = pred_disp
        self.pred_poses = pred_poses
        self.steps_per_epoch = loader.steps_per_epoch
        self.total_loss = total_loss
        self.pixel_loss = pixel_losses
        self.smooth_loss = smooth_losses
        self.tgt_image_all = tgt_image_stack_all
        self.src_image_stack_all = src_image_stack_all
        self.proj_image_stack_all = proj_image_stack_all
        self.proj_error_stack_all = proj_error_stack_all
        if self.auto_mask:
            self.pred_auto_masks1 = pred_auto_masks1
            self.pred_auto_masks2 = pred_auto_masks2
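
Finally, self.get_smooth_loss (used in both training graphs) is also left undefined by these examples. Below is a hedged sketch assuming the monodepth2-style edge-aware smoothness term: disparity gradients are penalized, but the penalty is downweighted where the image itself has strong gradients; the mean normalization and the 1e-7 epsilon are assumptions, not the original code.

import tensorflow as tf

def get_smooth_loss(disp, img):
    # Edge-aware smoothness on a mean-normalized disparity map ([B, H, W, 1]).
    disp = disp / (tf.reduce_mean(disp, axis=[1, 2], keepdims=True) + 1e-7)

    grad_disp_x = tf.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
    grad_disp_y = tf.abs(disp[:, :-1, :, :] - disp[:, 1:, :, :])

    grad_img_x = tf.reduce_mean(tf.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), axis=3, keepdims=True)
    grad_img_y = tf.reduce_mean(tf.abs(img[:, :-1, :, :] - img[:, 1:, :, :]), axis=3, keepdims=True)

    # Downweight disparity gradients where the image has strong edges.
    grad_disp_x *= tf.exp(-grad_img_x)
    grad_disp_y *= tf.exp(-grad_img_y)

    return tf.reduce_mean(grad_disp_x) + tf.reduce_mean(grad_disp_y)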