Example #1
def train_full_step(model_f, model_g, optimizer, x_all, y_all,
                    loss_distr_scale, gmm_scale):
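    """Run one training step for models f and g using the distribution loss.

    Assumes `tf` (TensorFlow 2.x) and the `losses` module are imported by the
    surrounding file.
    """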
    with tf.GradientTape() as tape_f0:
        f_x = model_f(x_all)
        dist_loss_f = loss_distr_scale * losses.distribution_loss(
            y_all, f_x, gmm_scale)
    gradients_target_fx = tape_f0.gradient(dist_loss_f,
                                           model_f.trainable_variables)
    with tf.GradientTape() as tape_g0:
        g_y = model_g(y_all)
        dist_loss_g = loss_distr_scale * losses.distribution_loss(
            x_all, g_y, gmm_scale)
    gradients_target_gy = tape_g0.gradient(dist_loss_g,
                                           model_g.trainable_variables)
    # Apply each model's gradients with the shared optimizer.
    optimizer.apply_gradients(
        zip(gradients_target_fx, model_f.trainable_variables))
    optimizer.apply_gradients(
        zip(gradients_target_gy, model_g.trainable_variables))
    return dist_loss_f, dist_loss_g
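A minimal usage sketch for this step function (hypothetical: the toy models, data shapes, and scale values below are illustrative and not taken from the original source; only `train_full_step` and the `losses` module come from the example above):

import tensorflow as tf

model_f = tf.keras.Sequential([tf.keras.layers.Dense(8, activation='relu'),
                               tf.keras.layers.Dense(1)])
model_g = tf.keras.Sequential([tf.keras.layers.Dense(8, activation='relu'),
                               tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
x_all = tf.random.normal([256, 1])
y_all = tf.random.normal([256, 1])

for step in range(1000):
    loss_f, loss_g = train_full_step(model_f, model_g, optimizer, x_all, y_all,
                                     loss_distr_scale=1.0, gmm_scale=0.1)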
Example #2
def direction_net_single(src_img, trt_img, rotation_gt, translation_gt):
    """Build the computation graph to train the DirectionNet-Single.

  Args:
    src_img: [BATCH, HEIGHT, WIDTH, 3] input source images.
    trt_img: [BATCH, HEIGHT, WIDTH, 3] input target images.
    rotation_gt: [BATCH, 3, 3] ground truth rotation matrices.
    translation_gt: [BATCH, 3] ground truth translation directions.

  Returns:
    A collection of tensors including training ops, loss, and global step count.
  """
    net = model.DirectionNet(4)
    global_step = tf.train.get_or_create_global_step()
    # Stack the three rotation rows with the translation direction:
    # [BATCH, 3, 3] + [BATCH, 1, 3] -> [BATCH, 4, 3].
    directions_gt = tf.concat(
        [rotation_gt, tf.expand_dims(translation_gt, 1)], 1)
    distribution_gt = util.spherical_normalization(util.von_mises_fisher(
        directions_gt, tf.constant(FLAGS.kappa, tf.float32),
        [FLAGS.distribution_height, FLAGS.distribution_width]),
                                                   rectify=False)

    pred = net(src_img, trt_img, training=True)
    directions, expectation, distribution_pred = util.distributions_to_directions(
        pred)
    rotation_estimated = util.svd_orthogonalize(directions[:, :3])

    direction_loss = losses.direction_loss(directions, directions_gt)
    distribution_loss = tf.constant(FLAGS.alpha,
                                    tf.float32) * losses.distribution_loss(
                                        distribution_pred, distribution_gt)
    spread_loss = tf.cast(FLAGS.beta,
                          tf.float32) * losses.spread_loss(expectation)
    rotation_error = tf.reduce_mean(
        util.rotation_geodesic(rotation_estimated, rotation_gt))
    translation_error = tf.reduce_mean(
        tf.acos(
            tf.clip_by_value(
                tf.reduce_sum(directions[:, -1] * directions_gt[:, -1], -1),
                -1., 1.)))
    direction_error = tf.reduce_mean(
        tf.acos(
            tf.clip_by_value(tf.reduce_sum(directions * directions_gt, -1),
                             -1., 1.)))

    loss = direction_loss + distribution_loss + spread_loss

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('distribution_loss', distribution_loss)
    tf.summary.scalar('spread_loss', spread_loss)
    tf.summary.scalar('direction_error',
                      util.radians_to_degrees(direction_error))
    tf.summary.scalar('rotation_error',
                      util.radians_to_degrees(rotation_error))
    tf.summary.scalar('translation_error',
                      util.radians_to_degrees(translation_error))

    for i in range(3):
        tf.summary.image('distribution/rotation/ground_truth_%d' % (i + 1),
                         distribution_gt[:, :, :, i:i + 1],
                         max_outputs=4)
        tf.summary.image('distribution/rotation/prediction_%d' % (i + 1),
                         distribution_pred[:, :, :, i:i + 1],
                         max_outputs=4)

    tf.summary.image('distribution/translation/ground_truth',
                     distribution_gt[:, :, :, -1:],
                     max_outputs=4)
    tf.summary.image('distribution/translation/prediction',
                     distribution_pred[:, :, :, -1:],
                     max_outputs=4)

    tf.summary.image('source_image', src_img, max_outputs=4)
    tf.summary.image('target_image', trt_img, max_outputs=4)

    optimizer = tf.train.GradientDescentOptimizer(FLAGS.lr)
    train_op = optimizer.minimize(loss, global_step=global_step, name='train')
    update_op = net.updates
    return Computation(tf.group([train_op, update_op]), loss, global_step)
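A hedged driver sketch for graph-mode builders like this one (assumptions not shown in the original source: `Computation` is a namedtuple with fields `train_op`, `loss`, and `global_step`; `FLAGS` is configured as in the original file; and the four input tensors come from an input pipeline that is omitted here):

import tensorflow.compat.v1 as tf  # the builder uses TF1 graph-mode APIs
tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    # src_img, trt_img, rotation_gt, translation_gt would come from a
    # tf.data pipeline or placeholders (omitted in this sketch).
    computation = direction_net_single(src_img, trt_img,
                                       rotation_gt, translation_gt)
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir='/tmp/directionnet_single') as sess:
        while not sess.should_stop():
            _, loss_val, step = sess.run([computation.train_op,
                                          computation.loss,
                                          computation.global_step])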
Example #3
def direction_net_rotation(src_img,
                           trt_img,
                           rotation_gt,
                           n_output_distributions=3):
    """Build the computation graph to train the DirectionNet-R.

  Args:
    src_img: [BATCH, HEIGHT, WIDTH, 3] input source images.
    trt_img: [BATCH, HEIGHT, WIDTH, 3] input target images.
    rotation_gt: [BATCH, 3, 3] ground truth rotation matrices.
    n_output_distributions: (int) number of output distributions. (either two or
    three) The model uses 9D representation for rotations when it is 3 and the
    model uses 6D representation when it is 2.

  Returns:
    A collection of tensors including training ops, loss, and global step count.

  Raises:
    ValueError: 'n_output_distributions' must be either 2 or 3.
  """
    if n_output_distributions != 3 and n_output_distributions != 2:
        raise ValueError("'n_output_distributions' must be either 2 or 3.")

    net = model.DirectionNet(n_output_distributions)
    global_step = tf.train.get_or_create_global_step()
    directions_gt = rotation_gt[:, :n_output_distributions]
    distribution_gt = util.spherical_normalization(util.von_mises_fisher(
        directions_gt, tf.constant(FLAGS.kappa, tf.float32),
        [FLAGS.distribution_height, FLAGS.distribution_width]),
                                                   rectify=False)

    pred = net(src_img, trt_img, training=True)
    directions, expectation, distribution_pred = util.distributions_to_directions(
        pred)
    if n_output_distributions == 3:
        rotation_estimated = util.svd_orthogonalize(directions)
    elif n_output_distributions == 2:
        rotation_estimated = util.gram_schmidt(directions)

    direction_loss = losses.direction_loss(directions, directions_gt)
    distribution_loss = tf.constant(FLAGS.alpha,
                                    tf.float32) * losses.distribution_loss(
                                        distribution_pred, distribution_gt)
    spread_loss = tf.cast(FLAGS.beta,
                          tf.float32) * losses.spread_loss(expectation)
    rotation_error = tf.reduce_mean(
        util.rotation_geodesic(rotation_estimated, rotation_gt))
    direction_error = tf.reduce_mean(
        tf.acos(
            tf.clip_by_value(tf.reduce_sum(directions * directions_gt, -1),
                             -1., 1.)))

    loss = direction_loss + distribution_loss + spread_loss

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('distribution_loss', distribution_loss)
    tf.summary.scalar('spread_loss', spread_loss)
    tf.summary.scalar('direction_error',
                      util.radians_to_degrees(direction_error))
    tf.summary.scalar('rotation_error',
                      util.radians_to_degrees(rotation_error))

    for i in range(n_output_distributions):
        tf.summary.image('distribution/rotation/ground_truth_%d' % (i + 1),
                         distribution_gt[:, :, :, i:i + 1],
                         max_outputs=4)
        tf.summary.image('distribution/rotation/prediction_%d' % (i + 1),
                         distribution_pred[:, :, :, i:i + 1],
                         max_outputs=4)

    tf.summary.image('source_image', src_img, max_outputs=4)
    tf.summary.image('target_image', trt_img, max_outputs=4)

    optimizer = tf.train.GradientDescentOptimizer(FLAGS.lr)
    train_op = optimizer.minimize(loss, global_step=global_step, name='train')
    update_op = net.updates
    return Computation(tf.group([train_op, update_op]), loss, global_step)
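For the 6D path above, the rotation matrix is assembled from two predicted direction vectors. A minimal sketch of such a Gram-Schmidt step (hypothetical: the repository's util.gram_schmidt is not shown and may differ, e.g. in row vs. column conventions):

import tensorflow as tf

def gram_schmidt_sketch(directions):
    # directions: [BATCH, 2, 3], two roughly orthogonal direction estimates.
    x = tf.math.l2_normalize(directions[:, 0], axis=-1)
    # Remove the component of the second vector along the first, then
    # normalize.
    proj = tf.reduce_sum(x * directions[:, 1], axis=-1, keepdims=True)
    y = tf.math.l2_normalize(directions[:, 1] - proj * x, axis=-1)
    # The cross product supplies the third axis of a right-handed frame.
    z = tf.linalg.cross(x, y)
    return tf.stack([x, y, z], axis=1)  # [BATCH, 3, 3] rotation matrices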
Example #4
def direction_net_translation(src_img,
                              trt_img,
                              rotation_gt,
                              translation_gt,
                              fov_gt,
                              rotation_pred,
                              derotate_both=False):
    """Build the computation graph to train the DirectionNet-T.

  Args:
    src_img: [BATCH, HEIGHT, WIDTH, 3] input source images.
    trt_img: [BATCH, HEIGHT, WIDTH, 3] input target images.
    rotation_gt: [BATCH, 3, 3] ground truth rotation matrices.
    translation_gt: [BATCH, 3] ground truth translation directions.
    fov_gt: [BATCH] the ground truth field of view (degrees) of input images.
    rotation_pred: [BATCH, 3, 3] estimated rotations from DirectionNet-R.
    derotate_both: (bool) transform both input images to a middle frame by half
      the relative rotation between them to cancel out the rotation if true.
      Otherwise, only derotate the target image to the source image's frame.

  Returns:
    A collection of tensors including training ops, loss, and global step count.
  """
    net = model.DirectionNet(1)
    global_step = tf.train.get_or_create_global_step()
    perturbed_rotation = tf.cond(
        tf.less(tf.random_uniform([], 0, 1.0), 0.5),
        lambda: util.perturb_rotation(rotation_gt, [10., 5., 10.]),
        lambda: rotation_pred)

    (transformed_src, transformed_trt) = util.derotation(
        src_img, trt_img, perturbed_rotation, fov_gt, FLAGS.transformed_fov,
        [FLAGS.transformed_height, FLAGS.transformed_width], derotate_both)

    (transformed_src_gt, transformed_trt_gt) = util.derotation(
        src_img, trt_img, rotation_gt, fov_gt, FLAGS.transformed_fov,
        [FLAGS.transformed_height, FLAGS.transformed_width], derotate_both)

    half_derotation = util.half_rotation(perturbed_rotation)
    translation_gt = tf.squeeze(
        tf.matmul(half_derotation,
                  tf.expand_dims(translation_gt, -1),
                  transpose_a=True), -1)
    translation_gt = tf.expand_dims(translation_gt, 1)
    distribution_gt = util.spherical_normalization(util.von_mises_fisher(
        translation_gt, tf.constant(FLAGS.kappa, tf.float32),
        [FLAGS.distribution_height, FLAGS.distribution_width]),
                                                   rectify=False)

    pred = net(transformed_src, transformed_trt, training=True)
    directions, expectation, distribution_pred = util.distributions_to_directions(
        pred)

    direction_loss = losses.direction_loss(directions, translation_gt)
    distribution_loss = tf.constant(FLAGS.alpha,
                                    tf.float32) * losses.distribution_loss(
                                        distribution_pred, distribution_gt)
    spread_loss = tf.cast(FLAGS.beta,
                          tf.float32) * losses.spread_loss(expectation)
    direction_error = tf.reduce_mean(
        tf.acos(
            tf.clip_by_value(tf.reduce_sum(directions * translation_gt, -1),
                             -1., 1.)))

    loss = direction_loss + distribution_loss + spread_loss

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('distribution_loss', distribution_loss)
    tf.summary.scalar('spread_loss', spread_loss)
    tf.summary.scalar('direction_error',
                      util.radians_to_degrees(direction_error))

    tf.summary.image('distribution/translation/ground_truth',
                     distribution_gt,
                     max_outputs=4)
    tf.summary.image('distribution/translation/prediction',
                     distribution_pred,
                     max_outputs=4)

    tf.summary.image('source_image', src_img, max_outputs=4)
    tf.summary.image('target_image', trt_img, max_outputs=4)
    tf.summary.image('transformed_source_image',
                     transformed_src,
                     max_outputs=4)
    tf.summary.image('transformed_target_image',
                     transformed_trt,
                     max_outputs=4)
    tf.summary.image('transformed_source_image_gt',
                     transformed_src_gt,
                     max_outputs=4)
    tf.summary.image('transformed_target_image_gt',
                     transformed_trt_gt,
                     max_outputs=4)

    optimizer = tf.train.GradientDescentOptimizer(FLAGS.lr)
    train_op = optimizer.minimize(loss, global_step=global_step, name='train')
    update_op = net.updates
    return Computation(tf.group([train_op, update_op]), loss, global_step)
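The half-rotation step above moves the translation label into the middle frame used when derotate_both is set. A hedged sketch of one way to halve a rotation via axis-angle and Rodrigues' formula (hypothetical: util.half_rotation is not shown and may be implemented differently; the near-identity degenerate case is not handled):

import tensorflow as tf

def half_rotation_sketch(r):
    # r: [BATCH, 3, 3] rotation matrices, assumed not (numerically) identity.
    angle = tf.acos(tf.clip_by_value((tf.linalg.trace(r) - 1.) / 2., -1., 1.))
    # Rotation axis from the skew-symmetric part of r.
    axis = tf.math.l2_normalize(
        tf.stack([r[:, 2, 1] - r[:, 1, 2],
                  r[:, 0, 2] - r[:, 2, 0],
                  r[:, 1, 0] - r[:, 0, 1]], axis=-1), axis=-1)
    kx, ky, kz = axis[:, 0], axis[:, 1], axis[:, 2]
    zero = tf.zeros_like(kx)
    # Skew-symmetric cross-product matrix K of the axis.
    k = tf.reshape(
        tf.stack([zero, -kz, ky, kz, zero, -kx, -ky, kx, zero], axis=-1),
        [-1, 3, 3])
    half = (angle / 2.)[:, tf.newaxis, tf.newaxis]
    eye = tf.eye(3, batch_shape=tf.shape(r)[:1])
    # Rodrigues: R(theta/2) = I + sin(theta/2) K + (1 - cos(theta/2)) K^2.
    return eye + tf.sin(half) * k + (1. - tf.cos(half)) * tf.matmul(k, k)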