def main(argv):
  del argv  # unused

  if FLAGS.checkpoint_dir is None:
    raise ValueError("`checkpoint_dir` must be defined")
  if FLAGS.data_dir is None:
    raise ValueError("`data_dir` must be defined")
  if FLAGS.output_dir is None:
    raise ValueError("`output_dir` must be defined")

  # Set up placeholders
  ref_image = tf.placeholder(dtype=tf.float32, shape=[None, height, width, 3])
  ref_depth = tf.placeholder(dtype=tf.float32, shape=[None, height, width])
  intrinsics = tf.placeholder(dtype=tf.float32, shape=[None, 3, 3])
  ref_pose = tf.placeholder(dtype=tf.float32, shape=[None, 4, 4])
  src_images = tf.placeholder(dtype=tf.float32, shape=[None, height, width, 3])
  src_poses = tf.placeholder(dtype=tf.float32, shape=[None, 4, 4, 1])
  env_pose = tf.placeholder(dtype=tf.float32, shape=[None, 4, 4])

  # Set up model
  model = MLV()

  # We use the true depth bounds for testing.
  # Adjust to estimated bounds for your dataset.
  min_depth = tf.reduce_min(ref_depth)
  max_depth = tf.reduce_max(ref_depth)

  # Set up graph
  mpi_planes = pj.inv_depths(min_depth, max_depth, num_planes)

  pred = model.infer_mpi(src_images, ref_image, ref_pose, src_poses,
                         intrinsics, mpi_planes)
  rgba_layers = pred["rgba_layers"]

  lightvols, lightvol_centers, \
  lightvol_side_lengths, \
  cube_rel_shapes, \
  cube_nest_inds = model.predict_lighting_vol(
      rgba_layers,
      mpi_planes,
      intrinsics,
      cube_res,
      scale_factors,
      depth_clip=depth_clip)

  lightvols_out = nets.cube_net_multires(lightvols, cube_rel_shapes,
                                         cube_nest_inds)

  output_envmap, _ = model.render_envmap(lightvols_out, lightvol_centers,
                                         lightvol_side_lengths,
                                         cube_rel_shapes, cube_nest_inds,
                                         ref_pose, env_pose, theta_res,
                                         phi_res, r_res)

  if not os.path.exists(FLAGS.output_dir):
    os.mkdir(FLAGS.output_dir)

  input_files = sorted(
      [f for f in os.listdir(FLAGS.data_dir) if f.endswith(".npz")])
  print("found {:05d} input files".format(len(input_files)))

  with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, os.path.join(FLAGS.checkpoint_dir, "model.ckpt"))

    for i in range(len(input_files)):
      print("running example:", i)

      # Load inputs
      batch = np.load(os.path.join(FLAGS.data_dir, input_files[i]))

      output_envmap_eval, = sess.run(
          [output_envmap],
          feed_dict={
              ref_image: batch["ref_image"],
              ref_depth: batch["ref_depth"],
              intrinsics: batch["intrinsics"],
              ref_pose: batch["ref_pose"],
              src_images: batch["src_images"],
              src_poses: batch["src_poses"],
              env_pose: batch["env_pose"]
          })

      # Write environment map image
      plt.imsave(
          os.path.join(FLAGS.output_dir, "{:05d}.png".format(i)),
          output_envmap_eval[0, :, :, :3])
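
# Example (sketch, not part of the original script): the eval loop above
# expects every .npz file in `data_dir` to provide exactly the keys fed into
# the placeholders. The helper below writes a shape-compatible dummy batch;
# `save_example_batch`, its default sizes, and the random depth range are
# hypothetical -- substitute real captures from your dataset.
def save_example_batch(path, height=240, width=320):
  np.savez(
      path,
      ref_image=np.zeros([1, height, width, 3], dtype=np.float32),
      ref_depth=np.random.uniform(
          1.0, 10.0, size=[1, height, width]).astype(np.float32),
      intrinsics=np.eye(3, dtype=np.float32)[None],
      ref_pose=np.eye(4, dtype=np.float32)[None],
      src_images=np.zeros([1, height, width, 3], dtype=np.float32),
      src_poses=np.eye(4, dtype=np.float32)[None, :, :, None],
      env_pose=np.eye(4, dtype=np.float32)[None])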
def build_train_graph(self,
                      inputs,
                      min_depth,
                      max_depth,
                      cube_res,
                      theta_res,
                      phi_res,
                      r_res,
                      scale_factors,
                      num_mpi_planes,
                      learning_rate=0.0001,
                      vgg_model_weights=None,
                      global_step=0,
                      depth_clip=20.0):
  """Construct the training computation graph.

  Args:
    inputs: dictionary of tensors (see 'input_data' below) needed for training
    min_depth: minimum depth for the PSV and MPI planes
    max_depth: maximum depth for the PSV and MPI planes
    cube_res: per-side cube resolution
    theta_res: environment map width
    phi_res: environment map height
    r_res: number of radii to use when sampling spheres for rendering
    scale_factors: downsampling factors of cubes relative to the coarsest
    num_mpi_planes: number of MPI planes to infer
    learning_rate: learning rate
    vgg_model_weights: vgg weights (needed when vgg loss is used)
    global_step: training iteration
    depth_clip: maximum depth for coarsest resampled volumes

  Returns:
    A train_op to be used for training.
  """
  with tf.name_scope('setup'):
    psv_planes = pj.inv_depths(min_depth, max_depth, num_mpi_planes)
    mpi_planes = pj.inv_depths(min_depth, max_depth, num_mpi_planes)

  with tf.name_scope('input_data'):
    tgt_image = inputs['tgt_image']
    ref_image = inputs['ref_image']
    src_images = inputs['src_images']
    env_image = inputs['env_image']

    ref_depth = inputs['ref_depth']

    tgt_pose = inputs['tgt_pose']
    ref_pose = inputs['ref_pose']
    src_poses = inputs['src_poses']
    env_pose = inputs['env_pose']

    intrinsics = inputs['intrinsics']

    _, _, _, num_source = src_poses.get_shape().as_list()

  with tf.name_scope('inference'):
    num_mpi_planes = tf.shape(mpi_planes)[0]
    pred = self.infer_mpi(src_images, ref_image, ref_pose, src_poses,
                          intrinsics, psv_planes)
    rgba_layers = pred['rgba_layers']
    psv = pred['psv']

  with tf.name_scope('synthesis'):
    output_image, output_alpha_acc, _ = self.mpi_render_view(
        rgba_layers, ref_pose, tgt_pose, mpi_planes, intrinsics)

  with tf.name_scope('environment_rendering'):
    mpi_gt = self.img2mpi(ref_image, ref_depth, mpi_planes)
    output_image_gt, _, _ = self.mpi_render_view(
        mpi_gt, ref_pose, tgt_pose, mpi_planes, intrinsics)

    lightvols_gt, _, _, _, _ = self.predict_lighting_vol(
        mpi_gt,
        mpi_planes,
        intrinsics,
        cube_res,
        scale_factors,
        depth_clip=depth_clip)

    lightvols, lightvol_centers, \
    lightvol_side_lengths, \
    cube_rel_shapes, \
    cube_nest_inds = self.predict_lighting_vol(
        rgba_layers,
        mpi_planes,
        intrinsics,
        cube_res,
        scale_factors,
        depth_clip=depth_clip)

    lightvols_out = nets.cube_net_multires(lightvols, cube_rel_shapes,
                                           cube_nest_inds)

    gt_envmap, gt_shells = self.render_envmap(lightvols_gt, lightvol_centers,
                                              lightvol_side_lengths,
                                              cube_rel_shapes, cube_nest_inds,
                                              ref_pose, env_pose, theta_res,
                                              phi_res, r_res)

    prenet_envmap, prenet_shells = self.render_envmap(
        lightvols, lightvol_centers, lightvol_side_lengths, cube_rel_shapes,
        cube_nest_inds, ref_pose, env_pose, theta_res, phi_res, r_res)

    output_envmap, output_shells = self.render_envmap(
        lightvols_out, lightvol_centers, lightvol_side_lengths,
        cube_rel_shapes, cube_nest_inds, ref_pose, env_pose, theta_res,
        phi_res, r_res)

  with tf.name_scope('loss'):
    # Mask loss for pixels outside the reference frustum
    loss_mask = tf.where(
        tf.equal(output_alpha_acc[Ellipsis, tf.newaxis], 0.0),
        tf.zeros_like(output_image[:, :, :, 0:1]),
        tf.ones_like(output_image[:, :, :, 0:1]))
    loss_mask = tf.stop_gradient(loss_mask)
    tf.summary.image('loss_mask', loss_mask)

    # Helper functions for loss
    def compute_error(real, fake, mask):
      mask = tf.ones_like(real) * mask
      return tf.reduce_sum(mask * tf.abs(fake - real)) / (
          tf.reduce_sum(mask) + 1.0e-8)
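
    # Note: compute_error is a masked mean absolute error. The mask is
    # broadcast across channels, |fake - real| is summed over the masked
    # pixels, and the sum is normalized by the mask area (the 1e-8 term
    # guards against division by zero when the mask is empty). The VGG
    # terms below reuse it at every feature scale, downsampling the mask
    # to match each activation's resolution.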
    # Normalized VGG loss
    def downsample(tensor, ds):
      return tf.nn.avg_pool(tensor, [1, ds, ds, 1], [1, ds, ds, 1], 'SAME')

    def vgg_loss(tgt_image, output_image, loss_mask, vgg_weights):
      """VGG activation loss definition."""

      vgg_real = nets.build_vgg19(tgt_image * 255.0, vgg_weights)
      rescaled_output_image = output_image * 255.0
      vgg_fake = nets.build_vgg19(rescaled_output_image, vgg_weights)

      p0 = compute_error(vgg_real['input'], vgg_fake['input'], loss_mask)
      p1 = compute_error(vgg_real['conv1_2'], vgg_fake['conv1_2'],
                         loss_mask) / 2.6
      p2 = compute_error(vgg_real['conv2_2'], vgg_fake['conv2_2'],
                         downsample(loss_mask, 2)) / 4.8
      p3 = compute_error(vgg_real['conv3_2'], vgg_fake['conv3_2'],
                         downsample(loss_mask, 4)) / 3.7
      p4 = compute_error(vgg_real['conv4_2'], vgg_fake['conv4_2'],
                         downsample(loss_mask, 8)) / 5.6
      p5 = compute_error(vgg_real['conv5_2'], vgg_fake['conv5_2'],
                         downsample(loss_mask, 16)) * 10 / 1.5

      total_loss = p0 + p1 + p2 + p3 + p4 + p5
      return total_loss

    # Rendered image loss
    render_loss = vgg_loss(tgt_image, output_image, loss_mask,
                           vgg_model_weights) / 100.0
    total_loss = render_loss

    # Rendered envmap loss
    envmap_loss = vgg_loss(env_image, output_envmap[Ellipsis, :3],
                           tf.ones_like(env_image[Ellipsis, 0:1]),
                           vgg_model_weights) / 100.0

    # Set the envmap loss to 0 while only the MPI network is trained (see paper)
    envmap_loss = tf.where(tf.greater(global_step, 240000), envmap_loss, 0.0)

    total_loss += envmap_loss

    # Adversarial loss for envmap
    real_logit = nets.discriminator(env_image, scope='discriminator')
    fake_logit = nets.discriminator(
        output_envmap[Ellipsis, :3], scope='discriminator')

    # Generator adversarial loss: raise the discriminator's score on fakes
    # (weighted by 0.1)
    adv_loss_list = []
    for i in range(len(fake_logit)):
      adv_loss_list.append(0.1 * -1.0 * tf.reduce_mean(fake_logit[i][-1]))
    adv_loss = tf.reduce_mean(adv_loss_list)

    # Hinge loss for the discriminator: penalize real logits below 1 and
    # fake logits above -1
    real_loss_list = []
    fake_loss_list = []
    for i in range(len(fake_logit)):
      real_loss_list.append(
          -1.0 * tf.reduce_mean(tf.minimum(real_logit[i][-1] - 1, 0.0)))
      fake_loss_list.append(-1.0 * tf.reduce_mean(
          tf.minimum(-1.0 * fake_logit[i][-1] - 1, 0.0)))
    real_loss = tf.reduce_mean(real_loss_list)
    fake_loss = tf.reduce_mean(fake_loss_list)
    disc_loss = real_loss + fake_loss

    # Set adversarial/discriminator losses to 0 until the end of training
    adv_loss = tf.where(tf.greater(global_step, 690000), adv_loss, 0.0)
    disc_loss = tf.where(tf.greater(global_step, 690000), disc_loss, 0.0)

    tf.summary.scalar('loss_disc', disc_loss)
    tf.summary.scalar('loss_disc_real', real_loss)
    tf.summary.scalar('loss_disc_fake', fake_loss)
    tf.summary.scalar('loss_adv', adv_loss)

    total_loss += adv_loss

  with tf.name_scope('train_op'):
    train_variables = [
        var for var in tf.trainable_variables()
        if 'discriminator' not in var.name
    ]
    optim = tf.train.AdamOptimizer(learning_rate, epsilon=1e-4)
    grads_and_variables = optim.compute_gradients(
        total_loss, var_list=train_variables)
    grads = [gv[0] for gv in grads_and_variables]
    variables = [gv[1] for gv in grads_and_variables]

    def denan(x):
      return tf.where(tf.is_nan(x), tf.zeros_like(x), x)

    grads_clipped = [denan(g) for g in grads]
    grads_clipped, _ = tf.clip_by_global_norm(grads_clipped, 100.0)
    train_op = [optim.apply_gradients(zip(grads_clipped, variables))]
    tf.summary.scalar('gradient global norm', tf.linalg.global_norm(grads))
    tf.summary.scalar('clipped gradient global norm',
                      tf.linalg.global_norm(grads_clipped))

    d_variables = [
        var for var in tf.trainable_variables() if 'discriminator' in var.name
    ]
    optim_d = tf.train.AdamOptimizer(learning_rate, beta1=0.0)
    train_op.append(optim_d.minimize(disc_loss, var_list=d_variables))
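
  # Note: `train_op` is a list of two updates -- the Adam step for all
  # non-discriminator variables (on NaN-scrubbed, globally-clipped
  # gradients) and a separate Adam step for the discriminator, so a single
  # `sess.run(train_op)` advances both players each iteration.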
  with tf.name_scope('envmap_gt'):
    tf.summary.image('envmap', gt_envmap)
    tf.summary.image('envmap_alpha', gt_envmap[Ellipsis, -1:])
    for i in range(len(gt_shells)):
      i_envmap = pj.over_composite(gt_shells[i])
      tf.summary.image('envmap_level_' + str(i), i_envmap)

  with tf.name_scope('envmap_prenet'):
    tf.summary.image('envmap', prenet_envmap)
    tf.summary.image('envmap_alpha', prenet_envmap[Ellipsis, -1:])
    for i in range(len(prenet_shells)):
      i_envmap = pj.over_composite(prenet_shells[i])
      tf.summary.image('envmap_level_' + str(i), i_envmap)

  with tf.name_scope('envmap_output'):
    tf.summary.image('envmap', output_envmap)
    tf.summary.image('envmap_alpha', output_envmap[Ellipsis, -1:])
    for i in range(len(output_shells)):
      i_envmap = pj.over_composite(output_shells[i])
      tf.summary.image('envmap_level_' + str(i), i_envmap)

  tf.summary.scalar('loss_total', total_loss)
  tf.summary.scalar('loss_render', render_loss)
  tf.summary.scalar('loss_envmap', envmap_loss)
  tf.summary.scalar('min_depth', min_depth)
  tf.summary.scalar('max_depth', max_depth)

  with tf.name_scope('level_stats'):
    for i in range(len(lightvols)):
      tf.summary.scalar('cube_side_length_' + str(i),
                        lightvol_side_lengths[i])
      tf.summary.scalar('cube_center_' + str(i), lightvol_centers[i][0, -1])

  # Source images
  for i in range(num_source):
    src_image = src_images[:, :, :, i * 3:(i + 1) * 3]
    tf.summary.image('image_src_%d' % i, src_image)
  # Output image
  tf.summary.image('image_output', output_image)
  tf.summary.image('image_output_gt', output_image_gt)
  # Target image
  tf.summary.image('image_tgt', tgt_image)
  tf.summary.image('envmap_tgt', env_image)
  # Reference image
  tf.summary.image('image_ref', ref_image)
  # Predicted color and alpha layers, and PSV
  num_summ = 8  # number of plane summaries to show in TensorBoard
  for i in range(num_summ):
    ind = tf.to_int32(i * num_mpi_planes / num_summ)
    rgb = rgba_layers[:, :, :, ind, :3]
    alpha = rgba_layers[:, :, :, ind, -1:]
    ref_plane = psv[:, :, :, ind, :3]
    source_plane = psv[:, :, :, ind, 3:6]
    tf.summary.image('layer_rgb_%d' % i, rgb)
    tf.summary.image('layer_alpha_%d' % i, alpha)
    tf.summary.image('layer_rgba_%d' % i, rgba_layers[:, :, :, ind, :])
    tf.summary.image('psv_avg_%d' % i, 0.5 * ref_plane + 0.5 * source_plane)
    tf.summary.image('psv_ref_%d' % i, ref_plane)
    tf.summary.image('psv_source_%d' % i, source_plane)

  return train_op
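
# Usage sketch (hedged): one way to drive `build_train_graph` in a TF1
# training loop. The hyperparameter values, `inputs` pipeline, VGG weight
# loading, and `num_training_steps` below are hypothetical placeholders,
# not values taken from this file:
#
#   model = MLV()
#   train_op = model.build_train_graph(
#       inputs, min_depth=1.0, max_depth=100.0, cube_res=64, theta_res=240,
#       phi_res=120, r_res=128, scale_factors=[2, 4, 8, 16],
#       num_mpi_planes=32, vgg_model_weights=vgg_weights,
#       global_step=global_step)
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     for _ in range(num_training_steps):
#       sess.run(train_op)  # one generator step and one discriminator step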