def build_test(self, build_type='both'):
    """Build the inference graph.

    Creates uint8 image placeholders, runs the depth (disp) network on the
    target frame and — unless ``build_type == 'depth'`` — the pose network on
    (source, target) frame pairs.

    Args:
        build_type: 'depth' builds only the depth branch; any other value
            (default 'both') also builds the pose branch.

    Side effects (tensors exposed for session.run / feed_dict):
        self.tgt_image, self.src_image_stack  -- uint8 input placeholders
        self.pred_depth, self.pred_disp       -- full-resolution scale-0 outputs
        self.pred_poses                       -- 6-DoF poses for both sources
    """
    self.loader = DataLoader(trainable=False, **self.config)
    self.num_scales = self.loader.num_scales
    self.num_source = self.loader.num_source

    with tf.name_scope('data_loading'):
        self.tgt_image = tf.placeholder(
            tf.uint8,
            [self.loader.batch_size, self.loader.img_height, self.loader.img_width, 3])
        # convert_image_dtype rescales uint8 [0, 255] to float32 [0, 1].
        tgt_image = tf.image.convert_image_dtype(self.tgt_image, dtype=tf.float32)
        tgt_image_net = self.preprocess_image(tgt_image)
        if build_type != 'depth':
            # Source frames are stacked along channels: [prev(3) | next(3)].
            self.src_image_stack = tf.placeholder(
                tf.uint8,
                [self.loader.batch_size, self.loader.img_height,
                 self.loader.img_width, 3 * self.num_source])
            src_image_stack = tf.image.convert_image_dtype(self.src_image_stack, dtype=tf.float32)
            src_image_stack_net = self.preprocess_image(src_image_stack)

    with tf.variable_scope('monodepth2_model', reuse=tf.AUTO_REUSE) as scope:
        net_builder = Net(False, **self.config)

        res18_tc, skips_tc = net_builder.build_resnet18(tgt_image_net)
        pred_disp = net_builder.build_disp_net(res18_tc, skips_tc)

        # Upsample every scale to raw input resolution; only scale 0 is exposed.
        pred_disp_rawscale = [
            tf.image.resize_bilinear(pred_disp[i], [self.loader.img_height, self.loader.img_width])
            for i in range(self.num_scales)]
        pred_depth_rawscale = disp_to_depth(pred_disp_rawscale, self.min_depth, self.max_depth)

        self.pred_depth = pred_depth_rawscale[0]
        self.pred_disp = pred_disp_rawscale[0]

        if build_type != 'depth':
            # np.int was removed in NumPy 1.24 — use the builtin.
            num_source = int(src_image_stack_net.get_shape().as_list()[-1] // 3)
            assert num_source == 2

            if self.pose_type == 'seperate':
                # NOTE(review): ordering fixed to match build_train, which feeds
                # the previous-frame pair as concat([prev, target]). Feeding
                # concat([target, prev]) at test time would evaluate the pose
                # net with inputs swapped relative to how it was trained.
                res18_ctp, _ = net_builder.build_resnet18(
                    tf.concat([src_image_stack_net[:, :, :, :3], tgt_image_net], axis=3),
                    prefix='pose_')
                res18_ctn, _ = net_builder.build_resnet18(
                    tf.concat([tgt_image_net, src_image_stack_net[:, :, :, 3:]], axis=3),
                    prefix='pose_')
            elif self.pose_type == 'shared':
                # Shared encoder: features extracted per-frame, then fused.
                res18_tp, _ = net_builder.build_resnet18(src_image_stack_net[:, :, :, :3])
                res18_tn, _ = net_builder.build_resnet18(src_image_stack_net[:, :, :, 3:])
                # Feature concat order matches build_train: (prev, target) and
                # (target, next).
                res18_ctp = tf.concat([res18_tp, res18_tc], axis=3)
                res18_ctn = tf.concat([res18_tc, res18_tn], axis=3)
            else:
                raise NotImplementedError

            pred_pose_ctp = net_builder.build_pose_net2(res18_ctp)
            pred_pose_ctn = net_builder.build_pose_net2(res18_ctn)
            # Shape: [batch, 2, 6] — pose w.r.t. previous and next frame.
            pred_poses = tf.concat([pred_pose_ctp, pred_pose_ctn], axis=1)
            self.pred_poses = pred_poses
def build_train(self):
    """Build the full training graph.

    Pipeline: load an augmented batch, run the depth net on the target frame
    and the pose net on (source, target) pairs, then compute the monodepth2
    losses (minimum reprojection with optional auto-masking, plus
    edge-aware smoothness) over all scales, and create the Adam train op
    with a piecewise-constant learning-rate schedule.

    Side effects: exposes loss/prediction/summary tensors on ``self``
    (pred_depth, pred_disp, pred_poses, total_loss, train_op, ...).
    """
    # Hyper-parameters from config. Builtin float()/int() replace the
    # np.float/np.int aliases, which were deprecated in NumPy 1.20 and
    # removed in 1.24 (np.float(...) raises AttributeError there).
    self.ssim_ratio = float(self.config['model']['reproj_alpha'])
    self.smoothness_ratio = float(self.config['model']['smooth_alpha'])
    self.start_learning_rate = float(self.config['model']['learning_rate'])
    self.total_epoch = int(self.config['model']['epoch'])
    self.beta1 = float(self.config['model']['beta1'])
    self.continue_ckpt = self.config['model']['continue_ckpt']
    self.torch_res18_ckpt = self.config['model']['torch_res18_ckpt']
    self.summary_freq = self.config['model']['summary_freq']
    self.auto_mask = self.config['model']['auto_mask']

    loader = DataLoader(trainable=True, **self.config)
    self.num_scales = loader.num_scales
    self.num_source = loader.num_source

    with tf.name_scope('data_loading'):
        tgt_image, src_image_stack, tgt_image_aug, src_image_stack_aug, intrinsics = loader.load_batch()
        # convert_image_dtype rescales uint8 [0, 255] to float32 [0, 1].
        tgt_image = tf.image.convert_image_dtype(tgt_image, dtype=tf.float32)
        src_image_stack = tf.image.convert_image_dtype(src_image_stack, dtype=tf.float32)
        tgt_image_aug = tf.image.convert_image_dtype(tgt_image_aug, dtype=tf.float32)
        src_image_stack_aug = tf.image.convert_image_dtype(src_image_stack_aug, dtype=tf.float32)
        # Networks see the (optionally normalized) augmented images; the loss
        # below is computed against the *un-augmented* images.
        if self.preprocess:
            tgt_image_net = self.preprocess_image(tgt_image_aug)
            src_image_stack_net = self.preprocess_image(src_image_stack_aug)
        else:
            tgt_image_net = tgt_image_aug
            src_image_stack_net = src_image_stack_aug

    with tf.variable_scope('monodepth2_model', reuse=tf.AUTO_REUSE) as scope:
        net_builder = Net(True, **self.config)

        # Source frames are stacked along channels: [prev(3) | next(3)].
        num_source = int(src_image_stack_net.get_shape().as_list()[-1] // 3)
        assert num_source == 2

        res18_tc, skips_tc = net_builder.build_resnet18(tgt_image_net)

        if self.pose_type == 'seperate':
            # Separate pose encoder ('pose_' prefix): each pair is fed in
            # temporal order — (prev, target) and (target, next).
            res18_ctp, _ = net_builder.build_resnet18(
                tf.concat([src_image_stack_net[:, :, :, :3], tgt_image_net], axis=3),
                prefix='pose_')
            res18_ctn, _ = net_builder.build_resnet18(
                tf.concat([tgt_image_net, src_image_stack_net[:, :, :, 3:]], axis=3),
                prefix='pose_')
        elif self.pose_type == 'shared':
            # Shared encoder: per-frame features fused by channel concat,
            # again in temporal order.
            res18_tp, _ = net_builder.build_resnet18(src_image_stack_net[:, :, :, :3])
            res18_tn, _ = net_builder.build_resnet18(src_image_stack_net[:, :, :, 3:])
            res18_ctp = tf.concat([res18_tp, res18_tc], axis=3)
            res18_ctn = tf.concat([res18_tc, res18_tn], axis=3)
        else:
            raise NotImplementedError

        pred_pose_ctp = net_builder.build_pose_net2(res18_ctp)
        pred_pose_ctn = net_builder.build_pose_net2(res18_ctn)
        # Shape: [batch, 2, 6] — pose w.r.t. previous and next frame.
        pred_poses = tf.concat([pred_pose_ctp, pred_pose_ctn], axis=1)

        pred_disp = net_builder.build_disp_net(res18_tc, skips_tc)

        H = tgt_image.get_shape().as_list()[1]
        W = tgt_image.get_shape().as_list()[2]

        # Monodepth2-style: upsample every disparity scale to the input
        # resolution and compute the loss there, rather than downsampling
        # the images.
        pred_disp_rawscale = [
            tf.image.resize_bilinear(pred_disp[i], [loader.img_height, loader.img_width])
            for i in range(self.num_scales)]
        pred_depth_rawscale = disp_to_depth(pred_disp_rawscale, self.min_depth, self.max_depth)

        # Image pyramid only feeds the smoothness term (per-scale edges).
        tgt_image_pyramid = [
            tf.image.resize_nearest_neighbor(tgt_image, [H // (2 ** s), W // (2 ** s)])
            for s in range(self.num_scales)]

    with tf.name_scope('compute_loss'):
        tgt_image_stack_all = []
        src_image_stack_all = []
        proj_image_stack_all = []
        proj_error_stack_all = []
        pixel_losses = 0.
        smooth_losses = 0.
        total_loss = 0.
        if self.auto_mask:
            pred_auto_masks = []

        for s in range(loader.num_scales):
            reprojection_losses = []
            for i in range(num_source):
                # Warp source i into the target view. invert=True for the
                # previous frame (i == 0) because its predicted pose maps
                # prev->target rather than target->source.
                curr_proj_image = projective_inverse_warp(
                    src_image_stack[:, :, :, 3 * i:3 * (i + 1)],
                    tf.squeeze(pred_depth_rawscale[s], axis=3),
                    pred_poses[:, i, :],
                    intrinsics=intrinsics[:, 0, :, :],
                    invert=(i == 0))
                curr_proj_error = tf.abs(curr_proj_image - tgt_image)
                reprojection_losses.append(
                    self.compute_reprojection_loss(curr_proj_image, tgt_image))

                if i == 0:
                    proj_image_stack = curr_proj_image
                    proj_error_stack = curr_proj_error
                else:
                    proj_image_stack = tf.concat([proj_image_stack, curr_proj_image], axis=3)
                    proj_error_stack = tf.concat([proj_error_stack, curr_proj_error], axis=3)

            reprojection_losses = tf.concat(reprojection_losses, axis=3)
            combined = reprojection_losses

            if self.auto_mask:
                # Auto-masking (monodepth2): compare against the *unwarped*
                # source; static pixels where identity beats reprojection are
                # excluded by the min below.
                identity_reprojection_losses = []
                for i in range(num_source):
                    identity_reprojection_losses.append(
                        self.compute_reprojection_loss(
                            src_image_stack[:, :, :, 3 * i:3 * (i + 1)], tgt_image))
                identity_reprojection_losses = tf.concat(identity_reprojection_losses, axis=3)
                # Tiny noise breaks ties so argmin doesn't systematically
                # prefer the identity channels.
                identity_reprojection_losses += (
                    tf.random_normal(identity_reprojection_losses.get_shape()) * 1e-5)
                combined = tf.concat([identity_reprojection_losses, reprojection_losses], axis=3)
                # Channels 0-1 are identity; argmin > 1 means a warped source
                # won, i.e. the pixel is considered non-static (mask kept for
                # summaries, scaled to [0, 255]).
                pred_auto_masks.append(
                    tf.expand_dims(
                        tf.cast(tf.argmin(combined, axis=3) > 1, tf.float32) * 255, -1))

            # Per-pixel minimum over sources (and identity channels when
            # auto-masking) handles occlusion/disocclusion.
            reprojection_loss = tf.reduce_mean(tf.reduce_min(combined, axis=3))
            pixel_losses += reprojection_loss

            smooth_loss = self.get_smooth_loss(pred_disp[s], tgt_image_pyramid[s])
            # NOTE(review): smooth_losses accumulates the UN-weighted value
            # (used only for logging); the total loss uses the 1/2**s
            # scale-weighted version — presumably intentional, confirm.
            smooth_losses += smooth_loss
            smooth_loss /= (2 ** s)

            scale_total_loss = reprojection_loss + self.smoothness_ratio * smooth_loss
            total_loss += scale_total_loss

            tgt_image_stack_all.append(tgt_image)
            src_image_stack_all.append(src_image_stack_aug)
            proj_image_stack_all.append(proj_image_stack)
            proj_error_stack_all.append(proj_error_stack)

        total_loss /= loader.num_scales
        pixel_losses /= loader.num_scales
        smooth_losses /= loader.num_scales

    with tf.name_scope('train_op'):
        self.total_step = self.total_epoch * loader.steps_per_epoch
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        # Drop the learning rate by 10x for the final quarter of training.
        learning_rates = [self.start_learning_rate, self.start_learning_rate / 10]
        boundaries = [int(self.total_step * 3 / 4)]
        self.learning_rate = tf.train.piecewise_constant(
            self.global_step, boundaries, learning_rates)
        optimizer = tf.train.AdamOptimizer(self.learning_rate, self.beta1)
        # Run UPDATE_OPS (e.g. batch-norm moving averages) with each step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.train_op = optimizer.minimize(total_loss, global_step=self.global_step)
        self.incr_global_step = tf.assign(self.global_step, self.global_step + 1)

    # Collect tensors that are useful later (e.g. tf summary).
    self.pred_depth = pred_depth_rawscale
    self.pred_disp = pred_disp
    self.pred_poses = pred_poses
    self.steps_per_epoch = loader.steps_per_epoch
    self.total_loss = total_loss
    self.pixel_loss = pixel_losses
    self.smooth_loss = smooth_losses
    self.tgt_image_all = tgt_image_stack_all
    self.src_image_stack_all = src_image_stack_all
    self.proj_image_stack_all = proj_image_stack_all
    self.proj_error_stack_all = proj_error_stack_all
    if self.auto_mask:
        self.pred_auto_masks = pred_auto_masks