def direction_net_single(src_img, trt_img, rotation_gt, translation_gt):
    """Build the computation graph to train the DirectionNet-Single.

    A single DirectionNet with four output distributions is trained: the first
    three predicted directions are compared against the rows of `rotation_gt`
    and the last one against `translation_gt`. The minimized loss is the sum of
    the direction loss, the FLAGS.alpha-weighted distribution loss and the
    FLAGS.beta-weighted spread loss, optimized by gradient descent with
    learning rate FLAGS.lr. Scalar and image summaries are registered as a
    side effect.

    Args:
      src_img: [BATCH, HEIGHT, WIDTH, 3] input source images.
      trt_img: [BATCH, HEIGHT, WIDTH, 3] input target images.
      rotation_gt: [BATCH, 3, 3] ground truth rotation matrices.
      translation_gt: [BATCH, 3] ground truth translation directions. A tensor
        already shaped [BATCH, 1, 3] is accepted as well.

    Returns:
      A collection of tensors including training ops, loss, and global step
      count.
    """
    net = model.DirectionNet(4)
    global_step = tf.train.get_or_create_global_step()

    # tf.concat requires all operands to have the same rank. The documented
    # [BATCH, 3] translation is therefore lifted to [BATCH, 1, 3] so it can be
    # stacked below the three rows of the [BATCH, 3, 3] rotation. Reshaping
    # (rather than expand_dims) keeps an input that is already [BATCH, 1, 3]
    # working unchanged.
    translation_gt = tf.reshape(translation_gt, [-1, 1, 3])
    directions_gt = tf.concat([rotation_gt, translation_gt], 1)
    distribution_gt = util.spherical_normalization(
        util.von_mises_fisher(
            directions_gt,
            tf.constant(FLAGS.kappa, tf.float32),
            [FLAGS.distribution_height, FLAGS.distribution_width]),
        rectify=False)

    pred = net(src_img, trt_img, training=True)
    directions, expectation, distribution_pred = (
        util.distributions_to_directions(pred))
    # Only the first three directions describe the rotation; the fourth is the
    # translation direction.
    rotation_estimated = util.svd_orthogonalize(directions[:, :3])

    direction_loss = losses.direction_loss(directions, directions_gt)
    distribution_loss = (
        tf.constant(FLAGS.alpha, tf.float32) *
        losses.distribution_loss(distribution_pred, distribution_gt))
    spread_loss = (
        tf.cast(FLAGS.beta, tf.float32) * losses.spread_loss(expectation))

    # Error metrics are only reported through summaries; they are not part of
    # the optimized loss.
    rotation_error = tf.reduce_mean(
        util.rotation_geodesic(rotation_estimated, rotation_gt))
    # The dot products are clipped to [-1, 1] before acos so that rounding
    # error cannot push the argument outside acos' domain.
    translation_error = tf.reduce_mean(
        tf.acos(
            tf.clip_by_value(
                tf.reduce_sum(directions[:, -1] * directions_gt[:, -1], -1),
                -1., 1.)))
    direction_error = tf.reduce_mean(
        tf.acos(
            tf.clip_by_value(
                tf.reduce_sum(directions * directions_gt, -1), -1., 1.)))

    loss = direction_loss + distribution_loss + spread_loss

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('distribution_loss', distribution_loss)
    tf.summary.scalar('spread_loss', spread_loss)
    tf.summary.scalar('direction_error',
                      util.radians_to_degrees(direction_error))
    tf.summary.scalar('rotation_error',
                      util.radians_to_degrees(rotation_error))
    tf.summary.scalar('translation_error',
                      util.radians_to_degrees(translation_error))

    # Channels 0..2 of the distributions are the rotation directions, the last
    # channel is the translation direction.
    for i in range(3):
        tf.summary.image('distribution/rotation/ground_truth_%d' % (i + 1),
                         distribution_gt[:, :, :, i:i + 1],
                         max_outputs=4)
        tf.summary.image('distribution/rotation/prediction_%d' % (i + 1),
                         distribution_pred[:, :, :, i:i + 1],
                         max_outputs=4)

    tf.summary.image('distribution/translation/ground_truth',
                     distribution_gt[:, :, :, -1:],
                     max_outputs=4)
    tf.summary.image('distribution/translation/prediction',
                     distribution_pred[:, :, :, -1:],
                     max_outputs=4)

    tf.summary.image('source_image', src_img, max_outputs=4)
    tf.summary.image('target_image', trt_img, max_outputs=4)

    optimizer = tf.train.GradientDescentOptimizer(FLAGS.lr)
    train_op = optimizer.minimize(loss, global_step=global_step, name='train')
    # The network's own update ops are grouped with the train op so that a
    # single run call executes both.
    update_op = net.updates
    return Computation(tf.group([train_op, update_op]), loss, global_step)
def direction_net_rotation(src_img,
                           trt_img,
                           rotation_gt,
                           n_output_distributions=3):
    """Build the computation graph to train the DirectionNet-R.

    The network predicts `n_output_distributions` directions that are compared
    against the first `n_output_distributions` rows of `rotation_gt`. The
    minimized loss is the sum of the direction loss, the FLAGS.alpha-weighted
    distribution loss and the FLAGS.beta-weighted spread loss, optimized by
    gradient descent with learning rate FLAGS.lr. Scalar and image summaries
    are registered as a side effect.

    Args:
      src_img: [BATCH, HEIGHT, WIDTH, 3] input source images.
      trt_img: [BATCH, HEIGHT, WIDTH, 3] input target images.
      rotation_gt: [BATCH, 3, 3] ground truth rotation matrices.
      n_output_distributions: (int) number of output distributions, either two
        or three. With 3 the model uses the 9D rotation representation, with 2
        it uses the 6D representation.

    Returns:
      A collection of tensors including training ops, loss, and global step
      count.

    Raises:
      ValueError: 'n_output_distributions' must be either 2 or 3.
    """
    if n_output_distributions not in (2, 3):
        raise ValueError("'n_output_distributions' must be either 2 or 3.")

    network = model.DirectionNet(n_output_distributions)
    step = tf.train.get_or_create_global_step()

    # Ground truth: the leading rows of the rotation matrix and the spherical
    # distributions generated from them.
    target_directions = rotation_gt[:, :n_output_distributions]
    kappa = tf.constant(FLAGS.kappa, tf.float32)
    raw_target_distribution = util.von_mises_fisher(
        target_directions,
        kappa,
        [FLAGS.distribution_height, FLAGS.distribution_width])
    target_distribution = util.spherical_normalization(
        raw_target_distribution, rectify=False)

    # Forward pass and conversion of the predicted distributions.
    network_output = network(src_img, trt_img, training=True)
    (predicted_directions,
     predicted_expectation,
     predicted_distribution) = util.distributions_to_directions(network_output)

    # Three directions go through the SVD orthogonalization, two directions
    # through Gram-Schmidt. The argument check above guarantees that no other
    # count reaches this point.
    if n_output_distributions == 3:
        orthogonalize = util.svd_orthogonalize
    else:
        orthogonalize = util.gram_schmidt
    rotation_estimate = orthogonalize(predicted_directions)

    # Loss terms.
    direction_term = losses.direction_loss(
        predicted_directions, target_directions)
    alpha = tf.constant(FLAGS.alpha, tf.float32)
    distribution_term = alpha * losses.distribution_loss(
        predicted_distribution, target_distribution)
    beta = tf.cast(FLAGS.beta, tf.float32)
    spread_term = beta * losses.spread_loss(predicted_expectation)

    # Error metrics, reported through summaries only.
    geodesic = util.rotation_geodesic(rotation_estimate, rotation_gt)
    mean_rotation_error = tf.reduce_mean(geodesic)
    # Clip the dot product into acos' domain before taking the angle.
    cosine = tf.clip_by_value(
        tf.reduce_sum(predicted_directions * target_directions, -1), -1., 1.)
    mean_direction_error = tf.reduce_mean(tf.acos(cosine))

    total_loss = direction_term + distribution_term + spread_term

    for tag, value in (('loss', total_loss),
                       ('distribution_loss', distribution_term),
                       ('spread_loss', spread_term)):
        tf.summary.scalar(tag, value)
    tf.summary.scalar('direction_error',
                      util.radians_to_degrees(mean_direction_error))
    tf.summary.scalar('rotation_error',
                      util.radians_to_degrees(mean_rotation_error))

    # One ground-truth / prediction image pair per output distribution, with
    # 1-based numbering in the summary tags.
    for number in range(1, n_output_distributions + 1):
        channel = slice(number - 1, number)
        tf.summary.image('distribution/rotation/ground_truth_%d' % number,
                         target_distribution[:, :, :, channel],
                         max_outputs=4)
        tf.summary.image('distribution/rotation/prediction_%d' % number,
                         predicted_distribution[:, :, :, channel],
                         max_outputs=4)

    tf.summary.image('source_image', src_img, max_outputs=4)
    tf.summary.image('target_image', trt_img, max_outputs=4)

    sgd = tf.train.GradientDescentOptimizer(FLAGS.lr)
    minimize_op = sgd.minimize(total_loss, global_step=step, name='train')
    # The network's own update ops are grouped with the minimize op so that a
    # single run call executes both.
    network_updates = network.updates
    return Computation(
        tf.group([minimize_op, network_updates]), total_loss, step)
def direction_net_translation(src_img,
                              trt_img,
                              rotation_gt,
                              translation_gt,
                              fov_gt,
                              rotation_pred,
                              derotate_both=False):
    """Build the computation graph to train the DirectionNet-T.

    The input pair is derotated with either a perturbed ground truth rotation
    or the rotation predicted by DirectionNet-R (chosen at random with equal
    probability each time the op runs), and a single-output DirectionNet is
    trained on the derotated pair. The minimized loss is the sum of the
    direction loss, the FLAGS.alpha-weighted distribution loss and the
    FLAGS.beta-weighted spread loss, optimized by gradient descent with
    learning rate FLAGS.lr. Scalar and image summaries are registered as a
    side effect.

    Args:
      src_img: [BATCH, HEIGHT, WIDTH, 3] input source images.
      trt_img: [BATCH, HEIGHT, WIDTH, 3] input target images.
      rotation_gt: [BATCH, 3, 3] ground truth rotation matrices.
      translation_gt: [BATCH, 3] ground truth translation directions.
      fov_gt: [BATCH] the ground truth field of view (degrees) of input images.
      rotation_pred: [BATCH, 3, 3] estimated rotations from DirectionNet-R.
      derotate_both: (bool) transform both input images to a middle frame by
        half the relative rotation between them to cancel out the rotation if
        true. Otherwise, only derotate the target image to the source image's
        frame.

    Returns:
      A collection of tensors including training ops, loss, and global step
      count.
    """
    network = model.DirectionNet(1)
    step = tf.train.get_or_create_global_step()

    def _perturbed_ground_truth():
        return util.perturb_rotation(rotation_gt, [10., 5., 10.])

    def _predicted_rotation():
        return rotation_pred

    # Draw a uniform sample in [0, 1): below 0.5 the perturbed ground truth
    # rotation is used, otherwise the prediction of DirectionNet-R.
    coin = tf.random_uniform([], 0, 1.0)
    applied_rotation = tf.cond(
        tf.less(coin, 0.5), _perturbed_ground_truth, _predicted_rotation)

    def _derotate(rotation):
        # Both derotation calls differ only in the rotation they cancel.
        return util.derotation(
            src_img,
            trt_img,
            rotation,
            fov_gt,
            FLAGS.transformed_fov,
            [FLAGS.transformed_height, FLAGS.transformed_width],
            derotate_both)

    # Network inputs use the sampled rotation; the pair derotated with the
    # exact ground truth is only used for the image summaries below.
    derotated_src, derotated_trt = _derotate(applied_rotation)
    derotated_src_gt, derotated_trt_gt = _derotate(rotation_gt)

    # Express the ground truth translation in the frame obtained by applying
    # the transpose of the half rotation, then add a direction axis:
    # [BATCH, 3] -> [BATCH, 1, 3].
    # NOTE(review): this half rotation is applied regardless of
    # `derotate_both` - confirm that is intended when only the target image
    # is derotated.
    half_rotation = util.half_rotation(applied_rotation)
    translation_column = tf.expand_dims(translation_gt, -1)
    rotated_translation = tf.matmul(
        half_rotation, translation_column, transpose_a=True)
    target_direction = tf.expand_dims(tf.squeeze(rotated_translation, -1), 1)

    kappa = tf.constant(FLAGS.kappa, tf.float32)
    raw_target_distribution = util.von_mises_fisher(
        target_direction,
        kappa,
        [FLAGS.distribution_height, FLAGS.distribution_width])
    target_distribution = util.spherical_normalization(
        raw_target_distribution, rectify=False)

    # Forward pass and conversion of the predicted distribution.
    network_output = network(derotated_src, derotated_trt, training=True)
    (predicted_directions,
     predicted_expectation,
     predicted_distribution) = util.distributions_to_directions(network_output)

    # Loss terms.
    direction_term = losses.direction_loss(
        predicted_directions, target_direction)
    alpha = tf.constant(FLAGS.alpha, tf.float32)
    distribution_term = alpha * losses.distribution_loss(
        predicted_distribution, target_distribution)
    beta = tf.cast(FLAGS.beta, tf.float32)
    spread_term = beta * losses.spread_loss(predicted_expectation)

    # Mean angle between prediction and target, reported through summaries
    # only. The dot product is clipped into acos' domain first.
    cosine = tf.clip_by_value(
        tf.reduce_sum(predicted_directions * target_direction, -1), -1., 1.)
    mean_direction_error = tf.reduce_mean(tf.acos(cosine))

    total_loss = direction_term + distribution_term + spread_term

    for tag, value in (('loss', total_loss),
                       ('distribution_loss', distribution_term),
                       ('spread_loss', spread_term)):
        tf.summary.scalar(tag, value)
    tf.summary.scalar('direction_error',
                      util.radians_to_degrees(mean_direction_error))

    image_summaries = (
        ('distribution/translation/ground_truth', target_distribution),
        ('distribution/translation/prediction', predicted_distribution),
        ('source_image', src_img),
        ('target_image', trt_img),
        ('transformed_source_image', derotated_src),
        ('transformed_target_image', derotated_trt),
        ('transformed_source_image_gt', derotated_src_gt),
        ('transformed_target_image_gt', derotated_trt_gt),
    )
    for tag, images in image_summaries:
        tf.summary.image(tag, images, max_outputs=4)

    sgd = tf.train.GradientDescentOptimizer(FLAGS.lr)
    minimize_op = sgd.minimize(total_loss, global_step=step, name='train')
    # The network's own update ops are grouped with the minimize op so that a
    # single run call executes both.
    network_updates = network.updates
    return Computation(
        tf.group([minimize_op, network_updates]), total_loss, step)