def build_test(self, build_type='both'):
        """Build the inference graph, fed through uint8 placeholders.

        The depth branch is always built. Unless ``build_type == 'depth'`` the
        pose branch is built too, using the same input ordering and the same
        preprocessing rule as ``build_train`` so that restored weights see
        inputs laid out exactly as during training.

        Args:
            build_type: ``'depth'`` builds only the depth branch; any other
                value (default ``'both'``) also builds the pose branch.

        Sets:
            self.loader, self.num_scales, self.num_source: data-loader info.
            self.tgt_image: uint8 placeholder of shape
                [batch_size, img_height, img_width, 3].
            self.src_image_stack: uint8 placeholder of shape
                [batch_size, img_height, img_width, 3 * num_source]
                (pose branch only).
            self.pred_depth, self.pred_disp: scale-0 depth / disparity,
                bilinearly resized to the loader's image size.
            self.pred_poses: the two pose predictions concatenated along
                axis 1 (pose branch only).

        Raises:
            ValueError: if the source stack does not hold exactly two frames.
            NotImplementedError: if ``self.pose_type`` is neither
                ``'seperate'`` nor ``'shared'``.
        """
        build_pose = build_type != 'depth'

        self.loader = DataLoader(trainable=False, **self.config)
        self.num_scales = self.loader.num_scales
        self.num_source = self.loader.num_source
        batch_size = self.loader.batch_size
        img_height = self.loader.img_height
        img_width = self.loader.img_width

        def to_net_input(image):
            # Same rule as build_train: normalise only when preprocessing is
            # enabled, so test-time inputs match what the network trained on.
            return self.preprocess_image(image) if self.preprocess else image

        with tf.name_scope('data_loading'):
            self.tgt_image = tf.placeholder(
                tf.uint8, [batch_size, img_height, img_width, 3])
            tgt_image = tf.image.convert_image_dtype(self.tgt_image, dtype=tf.float32)
            tgt_image_net = to_net_input(tgt_image)
            if build_pose:
                self.src_image_stack = tf.placeholder(
                    tf.uint8, [batch_size, img_height, img_width, 3 * self.num_source])
                src_image_stack = tf.image.convert_image_dtype(
                    self.src_image_stack, dtype=tf.float32)
                src_image_stack_net = to_net_input(src_image_stack)

        with tf.variable_scope('monodepth2_model', reuse=tf.AUTO_REUSE):
            net_builder = Net(False, **self.config)

            # Depth branch.
            res18_tc, skips_tc = net_builder.build_resnet18(tgt_image_net)
            pred_disp = net_builder.build_disp_net(res18_tc, skips_tc)
            pred_disp_rawscale = [
                tf.image.resize_bilinear(pred_disp[s], [img_height, img_width])
                for s in range(self.num_scales)
            ]
            pred_depth_rawscale = disp_to_depth(pred_disp_rawscale, self.min_depth, self.max_depth)

            # Only the finest scale is exposed at inference time.
            self.pred_depth = pred_depth_rawscale[0]
            self.pred_disp = pred_disp_rawscale[0]

            if not build_pose:
                return

            # Pose branch.
            num_source = src_image_stack_net.get_shape().as_list()[-1] // 3
            if num_source != 2:
                raise ValueError(
                    'pose branch expects exactly 2 source frames, got {}'.format(num_source))
            src_prev = src_image_stack_net[:, :, :, :3]
            src_next = src_image_stack_net[:, :, :, 3:]

            # Both pairs are ordered (earlier frame, later frame) --
            # [prev, tgt] and [tgt, next] -- exactly as in build_train.
            if self.pose_type == 'seperate':
                # Separate encoder (prefix 'pose_') on channel-stacked pairs.
                res18_ctp, _ = net_builder.build_resnet18(
                    tf.concat([src_prev, tgt_image_net], axis=3), prefix='pose_')
                res18_ctn, _ = net_builder.build_resnet18(
                    tf.concat([tgt_image_net, src_next], axis=3), prefix='pose_')
            elif self.pose_type == 'shared':
                # Same encoder call as the depth branch (no prefix); under
                # AUTO_REUSE it presumably shares the depth encoder's weights.
                res18_tp, _ = net_builder.build_resnet18(src_prev)
                res18_tn, _ = net_builder.build_resnet18(src_next)
                res18_ctp = tf.concat([res18_tp, res18_tc], axis=3)
                res18_ctn = tf.concat([res18_tc, res18_tn], axis=3)
            else:
                raise NotImplementedError(
                    'unknown pose_type: {!r}'.format(self.pose_type))

            pred_pose_ctp = net_builder.build_pose_net2(res18_ctp)
            pred_pose_ctn = net_builder.build_pose_net2(res18_ctn)
            self.pred_poses = tf.concat([pred_pose_ctp, pred_pose_ctn], axis=1)
    def build_train(self):
        """Build the self-supervised training graph and its optimiser.

        Hyper-parameters are read from ``self.config['model']``. The networks
        are fed the augmented images (preprocessed when ``self.preprocess`` is
        set). The photometric losses compare the *unaugmented* source frames,
        warped into the target view, against the unaugmented target frame.

        For every scale ``s``:

        * the predicted disparity is resized to the input resolution and
          converted to depth;
        * each source frame is warped with ``projective_inverse_warp``, and
          the per-pixel minimum reprojection loss is taken (together with the
          identity reprojection losses when ``auto_mask`` is on);
        * a smoothness term on ``pred_disp[s]``, weighted by
          ``smooth_alpha / 2**s``, is added.

        The reported losses are averaged over scales. Training uses Adam with
        a learning rate that drops 10x after 3/4 of ``epoch * steps_per_epoch``
        steps.

        Sets (besides the hyper-parameters read from config):
            self.num_scales, self.num_source, self.steps_per_epoch,
            self.total_step, self.global_step, self.learning_rate,
            self.train_op, self.incr_global_step,
            self.pred_depth (per-scale list, input resolution),
            self.pred_disp (raw per-scale network output), self.pred_poses,
            self.total_loss, self.pixel_loss, self.smooth_loss,
            the per-scale summary lists self.tgt_image_all,
            self.src_image_stack_all, self.proj_image_stack_all,
            self.proj_error_stack_all, and self.pred_auto_masks when
            auto-masking is enabled.

        Raises:
            ValueError: if the source stack does not hold exactly two frames.
            NotImplementedError: for an unknown ``self.pose_type``.
        """
        # np.float / np.int were plain aliases of the builtins and are gone
        # in NumPy >= 1.24; the builtins give identical values.
        model_cfg = self.config['model']
        self.ssim_ratio = float(model_cfg['reproj_alpha'])
        self.smoothness_ratio = float(model_cfg['smooth_alpha'])
        self.start_learning_rate = float(model_cfg['learning_rate'])
        self.total_epoch = int(model_cfg['epoch'])
        self.beta1 = float(model_cfg['beta1'])
        self.continue_ckpt = model_cfg['continue_ckpt']
        self.torch_res18_ckpt = model_cfg['torch_res18_ckpt']
        self.summary_freq = model_cfg['summary_freq']
        self.auto_mask = model_cfg['auto_mask']

        loader = DataLoader(trainable=True, **self.config)
        self.num_scales = loader.num_scales
        self.num_source = loader.num_source

        with tf.name_scope('data_loading'):
            (tgt_image, src_image_stack, tgt_image_aug, src_image_stack_aug,
             intrinsics) = loader.load_batch()
            tgt_image = tf.image.convert_image_dtype(tgt_image, dtype=tf.float32)
            src_image_stack = tf.image.convert_image_dtype(src_image_stack, dtype=tf.float32)
            tgt_image_aug = tf.image.convert_image_dtype(tgt_image_aug, dtype=tf.float32)
            src_image_stack_aug = tf.image.convert_image_dtype(src_image_stack_aug, dtype=tf.float32)
            if self.preprocess:
                tgt_image_net = self.preprocess_image(tgt_image_aug)
                src_image_stack_net = self.preprocess_image(src_image_stack_aug)
            else:
                tgt_image_net = tgt_image_aug
                src_image_stack_net = src_image_stack_aug

        with tf.variable_scope('monodepth2_model', reuse=tf.AUTO_REUSE):
            net_builder = Net(True, **self.config)
            num_source = src_image_stack_net.get_shape().as_list()[-1] // 3
            if num_source != 2:
                raise ValueError(
                    'training expects exactly 2 source frames, got {}'.format(num_source))
            src_prev_net = src_image_stack_net[:, :, :, :3]
            src_next_net = src_image_stack_net[:, :, :, 3:]

            res18_tc, skips_tc = net_builder.build_resnet18(tgt_image_net)

            # Both pose pairs are ordered (earlier frame, later frame):
            # [prev, tgt] and [tgt, next]; the warp below inverts the pose
            # for source index 0 (the previous frame).
            if self.pose_type == 'seperate':
                res18_ctp, _ = net_builder.build_resnet18(
                    tf.concat([src_prev_net, tgt_image_net], axis=3), prefix='pose_')
                res18_ctn, _ = net_builder.build_resnet18(
                    tf.concat([tgt_image_net, src_next_net], axis=3), prefix='pose_')
            elif self.pose_type == 'shared':
                res18_tp, _ = net_builder.build_resnet18(src_prev_net)
                res18_tn, _ = net_builder.build_resnet18(src_next_net)
                res18_ctp = tf.concat([res18_tp, res18_tc], axis=3)
                res18_ctn = tf.concat([res18_tc, res18_tn], axis=3)
            else:
                raise NotImplementedError(
                    'unknown pose_type: {!r}'.format(self.pose_type))

            pred_pose_ctp = net_builder.build_pose_net2(res18_ctp)
            pred_pose_ctn = net_builder.build_pose_net2(res18_ctn)
            pred_poses = tf.concat([pred_pose_ctp, pred_pose_ctn], axis=1)

            pred_disp = net_builder.build_disp_net(res18_tc, skips_tc)

            H, W = tgt_image.get_shape().as_list()[1:3]
            pred_disp_rawscale = [
                tf.image.resize_bilinear(pred_disp[s], [loader.img_height, loader.img_width])
                for s in range(self.num_scales)
            ]
            pred_depth_rawscale = disp_to_depth(pred_disp_rawscale, self.min_depth, self.max_depth)
            # Target image downsampled by 2**s, paired with pred_disp[s] in the
            # smoothness term.
            tgt_image_pyramid = [
                tf.image.resize_nearest_neighbor(tgt_image, [H // (2 ** s), W // (2 ** s)])
                for s in range(self.num_scales)
            ]

        with tf.name_scope('compute_loss'):
            tgt_image_stack_all = []
            src_image_stack_all = []
            proj_image_stack_all = []
            proj_error_stack_all = []
            pred_auto_masks = []
            pixel_losses = 0.
            smooth_losses = 0.
            total_loss = 0.
            for s in range(self.num_scales):
                # Depth is already at input resolution, so the scale-0
                # intrinsics apply at every scale.
                depth_s = tf.squeeze(pred_depth_rawscale[s], axis=3)
                proj_images = []
                proj_errors = []
                reprojection_losses = []
                for i in range(num_source):
                    proj_image = projective_inverse_warp(
                        src_image_stack[:, :, :, 3 * i:3 * (i + 1)],
                        depth_s,
                        pred_poses[:, i, :],
                        intrinsics=intrinsics[:, 0, :, :],
                        invert=(i == 0))
                    proj_images.append(proj_image)
                    proj_errors.append(tf.abs(proj_image - tgt_image))
                    reprojection_losses.append(self.compute_reprojection_loss(proj_image, tgt_image))
                reprojection_losses = tf.concat(reprojection_losses, axis=3)

                combined = reprojection_losses
                if self.auto_mask:
                    identity_reprojection_losses = tf.concat(
                        [self.compute_reprojection_loss(src_image_stack[:, :, :, 3 * i:3 * (i + 1)], tgt_image)
                         for i in range(num_source)],
                        axis=3)
                    # Tiny noise, presumably to break exact ties between the
                    # identity and warped losses in the min below.
                    identity_reprojection_losses += tf.random_normal(
                        identity_reprojection_losses.get_shape()) * 1e-5
                    combined = tf.concat([identity_reprojection_losses, reprojection_losses], axis=3)
                    # The first num_source channels hold identity losses; an
                    # argmin past them means a warped frame won. Stored as 0/255.
                    pred_auto_masks.append(tf.expand_dims(
                        tf.cast(tf.argmin(combined, axis=3) >= num_source, tf.float32) * 255, -1))

                reprojection_loss = tf.reduce_mean(tf.reduce_min(combined, axis=3))
                pixel_losses += reprojection_loss

                smooth_loss = self.get_smooth_loss(pred_disp[s], tgt_image_pyramid[s])
                # The reported smooth loss is unscaled; the optimised term is
                # down-weighted by 2**s.
                smooth_losses += smooth_loss
                total_loss += reprojection_loss + self.smoothness_ratio * (smooth_loss / (2 ** s))

                # Per-scale tensors kept for summaries.
                # NOTE(review): the target is unaugmented but the source stack
                # is the augmented one, while proj images are warped from the
                # unaugmented sources -- confirm this pairing is intended.
                tgt_image_stack_all.append(tgt_image)
                src_image_stack_all.append(src_image_stack_aug)
                proj_image_stack_all.append(tf.concat(proj_images, axis=3))
                proj_error_stack_all.append(tf.concat(proj_errors, axis=3))

            total_loss /= self.num_scales
            pixel_losses /= self.num_scales
            smooth_losses /= self.num_scales

        with tf.name_scope('train_op'):
            self.total_step = self.total_epoch * loader.steps_per_epoch
            self.global_step = tf.Variable(0, name='global_step', trainable=False)
            # Full rate for the first 3/4 of training, then 10x lower.
            boundaries = [int(self.total_step * 3 / 4)]
            learning_rates = [self.start_learning_rate, self.start_learning_rate / 10]
            self.learning_rate = tf.train.piecewise_constant(self.global_step, boundaries, learning_rates)
            optimizer = tf.train.AdamOptimizer(self.learning_rate, self.beta1)

            # Run pending UPDATE_OPS (typically batch-norm moving averages)
            # together with every optimisation step.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.train_op = optimizer.minimize(total_loss, global_step=self.global_step)

            # NOTE(review): minimize() above already increments global_step;
            # running incr_global_step in the same iteration counts it twice.
            self.incr_global_step = tf.assign(self.global_step, self.global_step + 1)

        # Collect tensors that are useful later (e.g. tf summary).
        self.pred_depth = pred_depth_rawscale
        self.pred_disp = pred_disp
        self.pred_poses = pred_poses
        self.steps_per_epoch = loader.steps_per_epoch
        self.total_loss = total_loss
        self.pixel_loss = pixel_losses
        self.smooth_loss = smooth_losses
        self.tgt_image_all = tgt_image_stack_all
        self.src_image_stack_all = src_image_stack_all
        self.proj_image_stack_all = proj_image_stack_all
        self.proj_error_stack_all = proj_error_stack_all
        if self.auto_mask:
            self.pred_auto_masks = pred_auto_masks