Example #1
0
    def build_train_graph(self,
                          inputs,
                          min_depth,
                          max_depth,
                          cube_res,
                          theta_res,
                          phi_res,
                          r_res,
                          scale_factors,
                          num_mpi_planes,
                          learning_rate=0.0001,
                          vgg_model_weights=None,
                          global_step=0,
                          depth_clip=20.0):
        """Construct the training computation graph.

    Args:
      inputs: dictionary of tensors (see 'input_data' below) needed for training
      min_depth: minimum depth for the PSV and MPI planes
      max_depth: maximum depth for the PSV and MPI planes
      cube_res: per-side cube resolution
      theta_res: environment map width
      phi_res: environment map height
      r_res: number of radii to use when sampling spheres for rendering
      scale_factors: downsampling factors of cubes relative to the coarsest
      num_mpi_planes: number of MPI planes to infer
      learning_rate: learning rate
      vgg_model_weights: vgg weights (needed when vgg loss is used)
      global_step: training iteration
      depth_clip: maximum depth for coarsest resampled volumes

    Returns:
      A list of two train ops: the first updates the MPI/lighting networks
      (all variables outside the discriminator scope), the second updates the
      discriminator.
    """
        with tf.name_scope('setup'):
            # Depths for the plane sweep volume (PSV) and the predicted MPI,
            # spaced in inverse depth between min_depth and max_depth.
            psv_planes = pj.inv_depths(min_depth, max_depth, num_mpi_planes)
            mpi_planes = pj.inv_depths(min_depth, max_depth, num_mpi_planes)

        with tf.name_scope('input_data'):

            # Images: target view, reference view, stacked source views, and
            # the ground-truth environment map.
            tgt_image = inputs['tgt_image']
            ref_image = inputs['ref_image']
            src_images = inputs['src_images']
            env_image = inputs['env_image']

            ref_depth = inputs['ref_depth']

            # Camera poses corresponding to each of the image streams above.
            tgt_pose = inputs['tgt_pose']
            ref_pose = inputs['ref_pose']
            src_poses = inputs['src_poses']
            env_pose = inputs['env_pose']

            intrinsics = inputs['intrinsics']

            # Number of source views, taken from the last axis of src_poses.
            # NOTE(review): assumes src_poses is [batch, 4, 4, num_source] --
            # confirm against the input pipeline.
            _, _, _, num_source = src_poses.get_shape().as_list()

        with tf.name_scope('inference'):
            # From here on, num_mpi_planes is a scalar *tensor* (the Python
            # int parameter is deliberately shadowed); it is reused below when
            # picking which planes to summarize.
            num_mpi_planes = tf.shape(mpi_planes)[0]
            pred = self.infer_mpi(src_images, ref_image, ref_pose, src_poses,
                                  intrinsics, psv_planes)
            rgba_layers = pred['rgba_layers']
            psv = pred['psv']

        with tf.name_scope('synthesis'):
            # Re-render the predicted MPI at the target pose; the accumulated
            # alpha is used below to mask pixels outside the reference frustum.
            output_image, output_alpha_acc, _ = self.mpi_render_view(
                rgba_layers, ref_pose, tgt_pose, mpi_planes, intrinsics)
        with tf.name_scope('environment_rendering'):
            # "Ground-truth" MPI built directly from the reference image and
            # its depth, giving a reference path for the lighting volumes.
            mpi_gt = self.img2mpi(ref_image, ref_depth, mpi_planes)
            output_image_gt, _, _ = self.mpi_render_view(
                mpi_gt, ref_pose, tgt_pose, mpi_planes, intrinsics)

            # Multiscale lighting volumes from the ground-truth MPI...
            lightvols_gt, _, _, _, _ = self.predict_lighting_vol(
                mpi_gt,
                mpi_planes,
                intrinsics,
                cube_res,
                scale_factors,
                depth_clip=depth_clip)

            # ...and from the predicted MPI (these also provide the cube
            # geometry metadata used by the envmap renderer below).
            lightvols, lightvol_centers, \
            lightvol_side_lengths, \
            cube_rel_shapes, \
            cube_nest_inds = self.predict_lighting_vol(rgba_layers, mpi_planes,
                                                       intrinsics, cube_res,
                                                       scale_factors,
                                                       depth_clip=depth_clip)

            # Refine the predicted volumes with the multiresolution cube net.
            lightvols_out = nets.cube_net_multires(lightvols, cube_rel_shapes,
                                                   cube_nest_inds)

            # Environment maps rendered from the GT volumes, the raw predicted
            # volumes, and the network-refined volumes, respectively.
            gt_envmap, gt_shells = self.render_envmap(
                lightvols_gt, lightvol_centers, lightvol_side_lengths,
                cube_rel_shapes, cube_nest_inds, ref_pose, env_pose, theta_res,
                phi_res, r_res)

            prenet_envmap, prenet_shells = self.render_envmap(
                lightvols, lightvol_centers, lightvol_side_lengths,
                cube_rel_shapes, cube_nest_inds, ref_pose, env_pose, theta_res,
                phi_res, r_res)

            output_envmap, output_shells = self.render_envmap(
                lightvols_out, lightvol_centers, lightvol_side_lengths,
                cube_rel_shapes, cube_nest_inds, ref_pose, env_pose, theta_res,
                phi_res, r_res)

        with tf.name_scope('loss'):
            # mask loss for pixels outside reference frustum (accumulated
            # alpha of exactly 0 means no MPI plane contributed there)
            loss_mask = tf.where(
                tf.equal(output_alpha_acc[Ellipsis, tf.newaxis], 0.0),
                tf.zeros_like(output_image[:, :, :, 0:1]),
                tf.ones_like(output_image[:, :, :, 0:1]))
            loss_mask = tf.stop_gradient(loss_mask)
            tf.summary.image('loss_mask', loss_mask)

            # helper functions for loss
            def compute_error(real, fake, mask):
                # Mean absolute error over the masked region only; the
                # epsilon avoids division by zero for an all-zero mask.
                mask = tf.ones_like(real) * mask
                return tf.reduce_sum(mask * tf.abs(fake - real)) / (
                    tf.reduce_sum(mask) + 1.0e-8)

            # Normalized VGG loss
            def downsample(tensor, ds):
                return tf.nn.avg_pool(tensor, [1, ds, ds, 1], [1, ds, ds, 1],
                                      'SAME')

            def vgg_loss(tgt_image, output_image, loss_mask, vgg_weights):
                """VGG activation loss definition."""

                # Images are scaled from [0, 1] to [0, 255] before VGG.
                vgg_real = nets.build_vgg19(tgt_image * 255.0, vgg_weights)
                rescaled_output_image = output_image * 255.0
                vgg_fake = nets.build_vgg19(rescaled_output_image, vgg_weights)
                # Per-layer masked L1 on VGG activations; the mask is
                # downsampled to each layer's spatial resolution and each term
                # carries an empirical normalization weight.
                p0 = compute_error(vgg_real['input'], vgg_fake['input'],
                                   loss_mask)
                p1 = compute_error(vgg_real['conv1_2'], vgg_fake['conv1_2'],
                                   loss_mask) / 2.6
                p2 = compute_error(vgg_real['conv2_2'], vgg_fake['conv2_2'],
                                   downsample(loss_mask, 2)) / 4.8
                p3 = compute_error(vgg_real['conv3_2'], vgg_fake['conv3_2'],
                                   downsample(loss_mask, 4)) / 3.7
                p4 = compute_error(vgg_real['conv4_2'], vgg_fake['conv4_2'],
                                   downsample(loss_mask, 8)) / 5.6
                p5 = compute_error(vgg_real['conv5_2'], vgg_fake['conv5_2'],
                                   downsample(loss_mask, 16)) * 10 / 1.5
                total_loss = p0 + p1 + p2 + p3 + p4 + p5
                return total_loss

            # rendered image loss
            render_loss = vgg_loss(tgt_image, output_image, loss_mask,
                                   vgg_model_weights) / 100.0
            total_loss = render_loss

            # rendered envmap loss (all-ones mask: every pixel counts)
            envmap_loss = vgg_loss(env_image, output_envmap[Ellipsis, :3],
                                   tf.ones_like(env_image[Ellipsis, 0:1]),
                                   vgg_model_weights) / 100.0

            # set envmap loss to 0 when only training mpi network (see paper)
            envmap_loss = tf.where(tf.greater(global_step, 240000),
                                   envmap_loss, 0.0)

            total_loss += envmap_loss

            # adversarial loss for envmap
            real_logit = nets.discriminator(env_image, scope='discriminator')
            fake_logit = nets.discriminator(output_envmap[Ellipsis, :3],
                                            scope='discriminator')
            # Generator term: push the final logits for the fake envmap up,
            # weighted by 0.1.
            adv_loss_list = []
            for i in range(len(fake_logit)):
                adv_loss_list.append(0.1 * -1.0 *
                                     tf.reduce_mean(fake_logit[i][-1]))
            adv_loss = tf.reduce_mean(adv_loss_list)
            # Discriminator term: hinge loss on real and fake logits
            # (-min(real - 1, 0) and -min(-fake - 1, 0)).
            real_loss_list = []
            fake_loss_list = []
            for i in range(len(fake_logit)):
                real_loss_list.append(
                    -1.0 *
                    tf.reduce_mean(tf.minimum(real_logit[i][-1] - 1, 0.0)))
                fake_loss_list.append(-1.0 * tf.reduce_mean(
                    tf.minimum(-1.0 * fake_logit[i][-1] - 1, 0.0)))
            real_loss = tf.reduce_mean(real_loss_list)
            fake_loss = tf.reduce_mean(fake_loss_list)
            disc_loss = real_loss + fake_loss

            # set adv/disc losses to 0 until end of training
            adv_loss = tf.where(tf.greater(global_step, 690000), adv_loss, 0.0)
            disc_loss = tf.where(tf.greater(global_step, 690000), disc_loss,
                                 0.0)

            tf.summary.scalar('loss_disc', disc_loss)
            tf.summary.scalar('loss_disc_real', real_loss)
            tf.summary.scalar('loss_disc_fake', fake_loss)
            tf.summary.scalar('loss_adv', adv_loss)

            total_loss += adv_loss

        with tf.name_scope('train_op'):
            # Generator update: every trainable variable except those in the
            # discriminator scope.
            train_variables = [
                var for var in tf.trainable_variables()
                if 'discriminator' not in var.name
            ]
            optim = tf.train.AdamOptimizer(learning_rate, epsilon=1e-4)
            grads_and_variables = optim.compute_gradients(
                total_loss, var_list=train_variables)
            grads = [gv[0] for gv in grads_and_variables]
            variables = [gv[1] for gv in grads_and_variables]

            def denan(x):
                # Zero out NaN gradient entries so one bad step cannot
                # poison the parameters.
                return tf.where(tf.is_nan(x), tf.zeros_like(x), x)

            # NaN-scrub, then clip to a global norm of 100 before applying.
            grads_clipped = [denan(g) for g in grads]
            grads_clipped, _ = tf.clip_by_global_norm(grads_clipped, 100.0)
            train_op = [optim.apply_gradients(zip(grads_clipped, variables))]
            tf.summary.scalar('gradient global norm',
                              tf.linalg.global_norm(grads))
            tf.summary.scalar('clipped gradient global norm',
                              tf.linalg.global_norm(grads_clipped))

            # Discriminator update with its own Adam optimizer (beta1=0).
            d_variables = [
                var for var in tf.trainable_variables()
                if 'discriminator' in var.name
            ]
            optim_d = tf.train.AdamOptimizer(learning_rate, beta1=0.0)
            train_op.append(optim_d.minimize(disc_loss, var_list=d_variables))

        # Environment map summaries for each rendering path, including a
        # per-scale composite of the spherical shells.
        with tf.name_scope('envmap_gt'):
            tf.summary.image('envmap', gt_envmap)
            tf.summary.image('envmap_alpha', gt_envmap[Ellipsis, -1:])
            for i in range(len(gt_shells)):
                i_envmap = pj.over_composite(gt_shells[i])
                tf.summary.image('envmap_level_' + str(i), i_envmap)
        with tf.name_scope('envmap_prenet'):
            tf.summary.image('envmap', prenet_envmap)
            tf.summary.image('envmap_alpha', prenet_envmap[Ellipsis, -1:])
            for i in range(len(prenet_shells)):
                i_envmap = pj.over_composite(prenet_shells[i])
                tf.summary.image('envmap_level_' + str(i), i_envmap)
        with tf.name_scope('envmap_output'):
            tf.summary.image('envmap', output_envmap)
            tf.summary.image('envmap_alpha', output_envmap[Ellipsis, -1:])
            for i in range(len(output_shells)):
                i_envmap = pj.over_composite(output_shells[i])
                tf.summary.image('envmap_level_' + str(i), i_envmap)

        tf.summary.scalar('loss_total', total_loss)
        tf.summary.scalar('loss_render', render_loss)
        tf.summary.scalar('loss_envmap', envmap_loss)
        tf.summary.scalar('min_depth', min_depth)
        tf.summary.scalar('max_depth', max_depth)

        with tf.name_scope('level_stats'):
            # Per-scale geometry statistics for the lighting volumes.
            for i in range(len(lightvols)):
                tf.summary.scalar('cube_side_length_' + str(i),
                                  lightvol_side_lengths[i])
                tf.summary.scalar('cube_center_' + str(i),
                                  lightvol_centers[i][0, -1])

        # Source images
        for i in range(num_source):
            # Source views are stacked along channels, 3 channels per view.
            src_image = src_images[:, :, :, i * 3:(i + 1) * 3]
            tf.summary.image('image_src_%d' % i, src_image)
        # Output image
        tf.summary.image('image_output', output_image)
        tf.summary.image('image_output_Gt', output_image_gt)
        # Target image
        tf.summary.image('image_tgt', tgt_image)
        tf.summary.image('envmap_tgt', env_image)
        # Ref image
        tf.summary.image('image_ref', ref_image)
        # Predicted color and alpha layers, and PSV
        num_summ = 8  # number of plane summaries to show in tensorboard
        for i in range(num_summ):
            # Sample num_summ planes evenly across the MPI depth planes
            # (num_mpi_planes is the tensor bound in 'inference' above).
            ind = tf.to_int32(i * num_mpi_planes / num_summ)
            rgb = rgba_layers[:, :, :, ind, :3]
            alpha = rgba_layers[:, :, :, ind, -1:]
            ref_plane = psv[:, :, :, ind, :3]
            source_plane = psv[:, :, :, ind, 3:6]
            tf.summary.image('layer_rgb_%d' % i, rgb)
            tf.summary.image('layer_alpha_%d' % i, alpha)
            tf.summary.image('layer_rgba_%d' % i, rgba_layers[:, :, :, ind, :])
            tf.summary.image('psv_avg_%d' % i,
                             0.5 * ref_plane + 0.5 * source_plane)
            tf.summary.image('psv_ref_%d' % i, ref_plane)
            tf.summary.image('psv_source_%d' % i, source_plane)

        return train_op
Example #2
0
    def render_envmap(self, cubes, cube_centers, cube_side_lengths,
                      cube_rel_shapes, cube_nest_inds, ref_pose, env_pose,
                      theta_res, phi_res, r_res):
        """Render an environment map from multiscale volumetric lights.

    Args:
      cubes: input list of cubes in the multiscale volume (the last entry is
        the finest resolution cube)
      cube_centers: position of cube centers
      cube_side_lengths: side lengths of cubes
      cube_rel_shapes: size of "footprint" of each cube within next coarser cube
      cube_nest_inds: start indices for cube "footprints" within the next
        coarser cube
      ref_pose: c2w pose of ref camera
      env_pose: c2w pose of environment map camera
      theta_res: resolution of theta (width) for environment map
      phi_res: resolution of phi (height) for environment map
      r_res: number of spherical shells to sample for environment map rendering

    Returns:
      A tuple of the environment map rendered at the input pose and the
      per-scale lists of resampled spherical shells.
    """
        # Transform taking environment-camera coordinates into the
        # reference-camera frame.
        env2ref = tf.matmul(tf.matrix_inverse(ref_pose), env_pose)

        # cube-->sphere resampling, one scale at a time.
        shell_list = []
        radii_list = []
        finest = len(cubes) - 1
        for idx, cube in enumerate(cubes):
            if idx == finest:
                # "finest" resolution cube, don't zero out
                masked_cube = cube
            else:
                # Zero out the region covered by the next finer cube so every
                # location is represented by exactly one scale.
                side = cube.get_shape().as_list()[1]

                footprint_ranges = [
                    tf.range(cube_nest_inds[idx][axis],
                             cube_nest_inds[idx][axis] + cube_rel_shapes[idx])
                    for axis in range(3)
                ]
                zm_y, zm_x, zm_z = tf.meshgrid(*footprint_ranges,
                                               indexing='ij')
                footprint = tf.stack([zm_y, zm_x, zm_z], axis=-1)
                ones = tf.to_float(tf.ones_like(zm_x))
                # Mask is 1 everywhere except inside the finer cube's region.
                keep_mask = 1.0 - tf.scatter_nd(
                    footprint, ones, shape=[side, side, side])
                masked_cube = keep_mask[tf.newaxis, :, :, :,
                                        tf.newaxis] * cube

            # Resample the (masked) cube onto spherical shells centered on
            # the environment camera.
            shells, radii = pj.spherical_cubevol_resample(
                masked_cube, env2ref, cube_centers[idx],
                cube_side_lengths[idx], phi_res, theta_res, r_res)
            shell_list.append(shells)
            radii_list.append(radii)

        # Gather shells from every scale, interleave them using their radii,
        # and over-composite into a single environment map.
        merged = pj.interleave_shells(
            tf.concat(shell_list, axis=3), tf.concat(radii_list, axis=0))
        envmap = pj.over_composite(merged)

        return envmap, shell_list