def main(argv):

    del argv  # unused

    if FLAGS.checkpoint_dir is None:
        raise ValueError("`checkpoint_dir` must be defined")
    if FLAGS.data_dir is None:
        raise ValueError("`data_dir` must be defined")
    if FLAGS.output_dir is None:
        raise ValueError("`output_dir` must be defined")

    # Set up placeholders
    ref_image = tf.placeholder(dtype=tf.float32,
                               shape=[None, height, width, 3])
    ref_depth = tf.placeholder(dtype=tf.float32, shape=[None, height, width])
    intrinsics = tf.placeholder(dtype=tf.float32, shape=[None, 3, 3])
    ref_pose = tf.placeholder(dtype=tf.float32, shape=[None, 4, 4])
    src_images = tf.placeholder(dtype=tf.float32,
                                shape=[None, height, width, 3])
    src_poses = tf.placeholder(dtype=tf.float32, shape=[None, 4, 4, 1])
    env_pose = tf.placeholder(dtype=tf.float32, shape=[None, 4, 4])

    # Set up model
    model = MLV()

    # We use the true depth bounds for testing;
    # adjust to estimated bounds for your dataset.
    min_depth = tf.reduce_min(ref_depth)
    max_depth = tf.reduce_max(ref_depth)

    # Set up graph
    mpi_planes = pj.inv_depths(min_depth, max_depth, num_planes)

    pred = model.infer_mpi(src_images, ref_image, ref_pose, src_poses,
                           intrinsics, mpi_planes)
    rgba_layers = pred["rgba_layers"]

    (lightvols, lightvol_centers, lightvol_side_lengths, cube_rel_shapes,
     cube_nest_inds) = model.predict_lighting_vol(rgba_layers, mpi_planes,
                                                  intrinsics, cube_res,
                                                  scale_factors,
                                                  depth_clip=depth_clip)
    lightvols_out = nets.cube_net_multires(lightvols, cube_rel_shapes,
                                           cube_nest_inds)
    output_envmap, _ = model.render_envmap(lightvols_out, lightvol_centers,
                                           lightvol_side_lengths,
                                           cube_rel_shapes, cube_nest_inds,
                                           ref_pose, env_pose, theta_res,
                                           phi_res, r_res)

    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    input_files = sorted(
        [f for f in os.listdir(FLAGS.data_dir) if f.endswith(".npz")])
    print("found {:05d} input files".format(len(input_files)))

    with tf.Session() as sess:
        saver = tf.train.Saver()
        saver.restore(sess, os.path.join(FLAGS.checkpoint_dir, "model.ckpt"))

        for i in range(len(input_files)):
            print("running example:", i)

            # Load inputs
            batch = np.load(os.path.join(FLAGS.data_dir, input_files[i]))

            output_envmap_eval, = sess.run(
                [output_envmap],
                feed_dict={
                    ref_image: batch["ref_image"],
                    ref_depth: batch["ref_depth"],
                    intrinsics: batch["intrinsics"],
                    ref_pose: batch["ref_pose"],
                    src_images: batch["src_images"],
                    src_poses: batch["src_poses"],
                    env_pose: batch["env_pose"]
                })

            # Write environment map image
            plt.imsave(os.path.join(FLAGS.output_dir, "{:05d}.png".format(i)),
                       output_envmap_eval[0, :, :, :3])
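
For reference, `main` relies on module-level flags and constants that fall outside this snippet. Below is a minimal sketch of that setup, assuming the standard absl entry-point pattern (which the `del argv` / `FLAGS` idiom suggests); the resolution and rendering values are illustrative placeholders, not taken from this file.

from absl import app
from absl import flags

flags.DEFINE_string("checkpoint_dir", None, "Directory containing model.ckpt.")
flags.DEFINE_string("data_dir", None, "Directory of .npz input examples.")
flags.DEFINE_string("output_dir", None, "Where to write environment maps.")
FLAGS = flags.FLAGS

# Illustrative values; set these to match your data and trained model.
height, width = 240, 320       # input image resolution
num_planes = 32                # number of MPI planes
cube_res = 64                  # per-side lighting-volume resolution
scale_factors = [2, 4, 8, 16]  # multiscale cube downsampling factors
depth_clip = 20.0              # depth clip for the coarsest volume
theta_res, phi_res, r_res = 240, 120, 8  # envmap sampling resolution

if __name__ == "__main__":
    app.run(main)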
Example #2
    def build_train_graph(self,
                          inputs,
                          min_depth,
                          max_depth,
                          cube_res,
                          theta_res,
                          phi_res,
                          r_res,
                          scale_factors,
                          num_mpi_planes,
                          learning_rate=0.0001,
                          vgg_model_weights=None,
                          global_step=0,
                          depth_clip=20.0):
        """Construct the training computation graph.

    Args:
      inputs: dictionary of tensors (see 'input_data' below) needed for training
      min_depth: minimum depth for the PSV and MPI planes
      max_depth: maximum depth for the PSV and MPI planes
      cube_res: per-side cube resolution
      theta_res: environment map width
      phi_res: environment map height
      r_res: number of radii to use when sampling spheres for rendering
      scale_factors: downsampling factors of cubes relative to the coarsest
      num_mpi_planes: number of MPI planes to infer
      learning_rate: learning rate
      vgg_model_weights: vgg weights (needed when vgg loss is used)
      global_step: training iteration
      depth_clip: maximum depth for coarsest resampled volumes

    Returns:
      A train_op to be used for training.
    """
        with tf.name_scope('setup'):
            psv_planes = pj.inv_depths(min_depth, max_depth, num_mpi_planes)
            mpi_planes = pj.inv_depths(min_depth, max_depth, num_mpi_planes)

        with tf.name_scope('input_data'):

            tgt_image = inputs['tgt_image']
            ref_image = inputs['ref_image']
            src_images = inputs['src_images']
            env_image = inputs['env_image']

            ref_depth = inputs['ref_depth']

            tgt_pose = inputs['tgt_pose']
            ref_pose = inputs['ref_pose']
            src_poses = inputs['src_poses']
            env_pose = inputs['env_pose']

            intrinsics = inputs['intrinsics']

            _, _, _, num_source = src_poses.get_shape().as_list()

        with tf.name_scope('inference'):
            # Shadows the `num_mpi_planes` argument with a runtime tensor,
            # used below to index plane summaries.
            num_mpi_planes = tf.shape(mpi_planes)[0]
            pred = self.infer_mpi(src_images, ref_image, ref_pose, src_poses,
                                  intrinsics, psv_planes)
            rgba_layers = pred['rgba_layers']
            psv = pred['psv']

        with tf.name_scope('synthesis'):
            output_image, output_alpha_acc, _ = self.mpi_render_view(
                rgba_layers, ref_pose, tgt_pose, mpi_planes, intrinsics)
        with tf.name_scope('environment_rendering'):
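            # Ground-truth supervision path: convert the reference image and
            # its depth directly into an MPI and matching lighting volumes.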
            mpi_gt = self.img2mpi(ref_image, ref_depth, mpi_planes)
            output_image_gt, _, _ = self.mpi_render_view(
                mpi_gt, ref_pose, tgt_pose, mpi_planes, intrinsics)

            lightvols_gt, _, _, _, _ = self.predict_lighting_vol(
                mpi_gt,
                mpi_planes,
                intrinsics,
                cube_res,
                scale_factors,
                depth_clip=depth_clip)

            (lightvols, lightvol_centers, lightvol_side_lengths,
             cube_rel_shapes, cube_nest_inds) = self.predict_lighting_vol(
                 rgba_layers, mpi_planes, intrinsics, cube_res, scale_factors,
                 depth_clip=depth_clip)

            lightvols_out = nets.cube_net_multires(lightvols, cube_rel_shapes,
                                                   cube_nest_inds)

            gt_envmap, gt_shells = self.render_envmap(
                lightvols_gt, lightvol_centers, lightvol_side_lengths,
                cube_rel_shapes, cube_nest_inds, ref_pose, env_pose, theta_res,
                phi_res, r_res)

            prenet_envmap, prenet_shells = self.render_envmap(
                lightvols, lightvol_centers, lightvol_side_lengths,
                cube_rel_shapes, cube_nest_inds, ref_pose, env_pose, theta_res,
                phi_res, r_res)

            output_envmap, output_shells = self.render_envmap(
                lightvols_out, lightvol_centers, lightvol_side_lengths,
                cube_rel_shapes, cube_nest_inds, ref_pose, env_pose, theta_res,
                phi_res, r_res)

        with tf.name_scope('loss'):
            # mask loss for pixels outside reference frustum
            loss_mask = tf.where(
                tf.equal(output_alpha_acc[Ellipsis, tf.newaxis], 0.0),
                tf.zeros_like(output_image[:, :, :, 0:1]),
                tf.ones_like(output_image[:, :, :, 0:1]))
            loss_mask = tf.stop_gradient(loss_mask)
            tf.summary.image('loss_mask', loss_mask)

            # helper functions for loss
            def compute_error(real, fake, mask):
                mask = tf.ones_like(real) * mask
                return tf.reduce_sum(mask * tf.abs(fake - real)) / (
                    tf.reduce_sum(mask) + 1.0e-8)

            # Normalized VGG loss
            def downsample(tensor, ds):
                return tf.nn.avg_pool(tensor, [1, ds, ds, 1], [1, ds, ds, 1],
                                      'SAME')

            def vgg_loss(tgt_image, output_image, loss_mask, vgg_weights):
                """VGG activation loss definition."""

                vgg_real = nets.build_vgg19(tgt_image * 255.0, vgg_weights)
                rescaled_output_image = output_image * 255.0
                vgg_fake = nets.build_vgg19(rescaled_output_image, vgg_weights)
                p0 = compute_error(vgg_real['input'], vgg_fake['input'],
                                   loss_mask)
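                # The per-layer divisors below appear to normalize typical
                # activation magnitudes (as in Chen & Koltun, "Photographic
                # Image Synthesis with Cascaded Refinement Networks", 2017),
                # so each VGG layer contributes comparably to the loss.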
                p1 = compute_error(vgg_real['conv1_2'], vgg_fake['conv1_2'],
                                   loss_mask) / 2.6
                p2 = compute_error(vgg_real['conv2_2'], vgg_fake['conv2_2'],
                                   downsample(loss_mask, 2)) / 4.8
                p3 = compute_error(vgg_real['conv3_2'], vgg_fake['conv3_2'],
                                   downsample(loss_mask, 4)) / 3.7
                p4 = compute_error(vgg_real['conv4_2'], vgg_fake['conv4_2'],
                                   downsample(loss_mask, 8)) / 5.6
                p5 = compute_error(vgg_real['conv5_2'], vgg_fake['conv5_2'],
                                   downsample(loss_mask, 16)) * 10 / 1.5
                total_loss = p0 + p1 + p2 + p3 + p4 + p5
                return total_loss

            # rendered image loss
            render_loss = vgg_loss(tgt_image, output_image, loss_mask,
                                   vgg_model_weights) / 100.0
            total_loss = render_loss

            # rendered envmap loss
            envmap_loss = vgg_loss(env_image, output_envmap[Ellipsis, :3],
                                   tf.ones_like(env_image[Ellipsis, 0:1]),
                                   vgg_model_weights) / 100.0

            # set envmap loss to 0 when only training mpi network (see paper)
            envmap_loss = tf.where(tf.greater(global_step, 240000),
                                   envmap_loss, 0.0)

            total_loss += envmap_loss

            # adversarial loss for envmap
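            # This appears to be a multi-scale hinge GAN: the generator term
            # maximizes the discriminator's logits on generated envmaps, and
            # the discriminator applies hinge losses to real and fake logits
            # at each scale ([-1] selects the final logit of each scale).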
            real_logit = nets.discriminator(env_image, scope='discriminator')
            fake_logit = nets.discriminator(output_envmap[Ellipsis, :3],
                                            scope='discriminator')
            adv_loss_list = []
            for i in range(len(fake_logit)):
                adv_loss_list.append(-0.1 * tf.reduce_mean(fake_logit[i][-1]))
            adv_loss = tf.reduce_mean(adv_loss_list)
            real_loss_list = []
            fake_loss_list = []
            for i in range(len(fake_logit)):
                real_loss_list.append(
                    -tf.reduce_mean(tf.minimum(real_logit[i][-1] - 1, 0.0)))
                fake_loss_list.append(
                    -tf.reduce_mean(tf.minimum(-fake_logit[i][-1] - 1, 0.0)))
            real_loss = tf.reduce_mean(real_loss_list)
            fake_loss = tf.reduce_mean(fake_loss_list)
            disc_loss = real_loss + fake_loss

            # set adv/disc losses to 0 until end of training
            adv_loss = tf.where(tf.greater(global_step, 690000), adv_loss, 0.0)
            disc_loss = tf.where(tf.greater(global_step, 690000), disc_loss,
                                 0.0)

            tf.summary.scalar('loss_disc', disc_loss)
            tf.summary.scalar('loss_disc_real', real_loss)
            tf.summary.scalar('loss_disc_fake', fake_loss)
            tf.summary.scalar('loss_adv', adv_loss)

            total_loss += adv_loss

        with tf.name_scope('train_op'):
            train_variables = [
                var for var in tf.trainable_variables()
                if 'discriminator' not in var.name
            ]
            optim = tf.train.AdamOptimizer(learning_rate, epsilon=1e-4)
            grads_and_variables = optim.compute_gradients(
                total_loss, var_list=train_variables)
            grads = [gv[0] for gv in grads_and_variables]
            variables = [gv[1] for gv in grads_and_variables]

            def denan(x):
                # Replace NaN gradients with zeros so they don't poison the
                # global-norm clipping below.
                return tf.where(tf.is_nan(x), tf.zeros_like(x), x)

            grads_clipped = [denan(g) for g in grads]
            grads_clipped, _ = tf.clip_by_global_norm(grads_clipped, 100.0)
            train_op = [optim.apply_gradients(zip(grads_clipped, variables))]
            tf.summary.scalar('gradient_global_norm',
                              tf.linalg.global_norm(grads))
            tf.summary.scalar('clipped_gradient_global_norm',
                              tf.linalg.global_norm(grads_clipped))

            d_variables = [
                var for var in tf.trainable_variables()
                if 'discriminator' in var.name
            ]
            optim_d = tf.train.AdamOptimizer(learning_rate, beta1=0.0)
            train_op.append(optim_d.minimize(disc_loss, var_list=d_variables))

        with tf.name_scope('envmap_gt'):
            tf.summary.image('envmap', gt_envmap)
            tf.summary.image('envmap_alpha', gt_envmap[Ellipsis, -1:])
            for i in range(len(gt_shells)):
                i_envmap = pj.over_composite(gt_shells[i])
                tf.summary.image('envmap_level_' + str(i), i_envmap)
        with tf.name_scope('envmap_prenet'):
            tf.summary.image('envmap', prenet_envmap)
            tf.summary.image('envmap_alpha', prenet_envmap[Ellipsis, -1:])
            for i in range(len(prenet_shells)):
                i_envmap = pj.over_composite(prenet_shells[i])
                tf.summary.image('envmap_level_' + str(i), i_envmap)
        with tf.name_scope('envmap_output'):
            tf.summary.image('envmap', output_envmap)
            tf.summary.image('envmap_alpha', output_envmap[Ellipsis, -1:])
            for i in range(len(output_shells)):
                i_envmap = pj.over_composite(output_shells[i])
                tf.summary.image('envmap_level_' + str(i), i_envmap)

        tf.summary.scalar('loss_total', total_loss)
        tf.summary.scalar('loss_render', render_loss)
        tf.summary.scalar('loss_envmap', envmap_loss)
        tf.summary.scalar('min_depth', min_depth)
        tf.summary.scalar('max_depth', max_depth)

        with tf.name_scope('level_stats'):
            for i in range(len(lightvols)):
                tf.summary.scalar('cube_side_length_' + str(i),
                                  lightvol_side_lengths[i])
                tf.summary.scalar('cube_center_' + str(i),
                                  lightvol_centers[i][0, -1])

        # Source images
        for i in range(num_source):
            src_image = src_images[:, :, :, i * 3:(i + 1) * 3]
            tf.summary.image('image_src_%d' % i, src_image)
        # Output image
        tf.summary.image('image_output', output_image)
        tf.summary.image('image_output_gt', output_image_gt)
        # Target image
        tf.summary.image('image_tgt', tgt_image)
        tf.summary.image('envmap_tgt', env_image)
        # Ref image
        tf.summary.image('image_ref', ref_image)
        # Predicted color and alpha layers, and PSV
        num_summ = 8  # number of plane summaries to show in tensorboard
        for i in range(num_summ):
            ind = tf.to_int32(i * num_mpi_planes / num_summ)
            rgb = rgba_layers[:, :, :, ind, :3]
            alpha = rgba_layers[:, :, :, ind, -1:]
            ref_plane = psv[:, :, :, ind, :3]
            source_plane = psv[:, :, :, ind, 3:6]
            tf.summary.image('layer_rgb_%d' % i, rgb)
            tf.summary.image('layer_alpha_%d' % i, alpha)
            tf.summary.image('layer_rgba_%d' % i, rgba_layers[:, :, :, ind, :])
            tf.summary.image('psv_avg_%d' % i,
                             0.5 * ref_plane + 0.5 * source_plane)
            tf.summary.image('psv_ref_%d' % i, ref_plane)
            tf.summary.image('psv_source_%d' % i, source_plane)

        return train_op
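
To show how this graph is meant to be driven, here is a minimal, hypothetical training-loop sketch. It assumes `build_train_graph` is a method of the `MLV` class from the first example; `get_inputs`, the VGG weight loading, and all hyperparameter values are illustrative assumptions, not taken from this file.

model = MLV()
inputs = get_inputs()  # hypothetical: batched tensors from a tf.data pipeline
vgg_weights = ...      # placeholder: pre-trained VGG weights for nets.build_vgg19
global_step = tf.train.get_or_create_global_step()

train_ops = model.build_train_graph(
    inputs, min_depth=1.0, max_depth=100.0, cube_res=64,
    theta_res=240, phi_res=120, r_res=8, scale_factors=[2, 4, 8, 16],
    num_mpi_planes=32, vgg_model_weights=vgg_weights,
    global_step=global_step)
# build_train_graph does not advance global_step itself, so step it explicitly
# (the step gates when the envmap and adversarial losses switch on).
increment_step = tf.assign_add(global_step, 1)

num_train_steps = 100  # illustrative
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(num_train_steps):
        # One call runs both the generator and the discriminator updates.
        sess.run(train_ops + [increment_step])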