示例#1
0
文件: tfu3d.py 项目: isarandi/metrabs
def center_relative_pose(coords3d,
                         joint_validity_mask=None,
                         center_is_mean=False,
                         center_joints=None):
    """Subtract a per-pose center point from every joint.

    Args:
        coords3d: NumPy array or TensorFlow tensor of joint coordinates.
            Axis 1 is used as the joint axis.
        joint_validity_mask: Optional boolean mask selecting the joints that
            may contribute to the mean. Only consulted when
            ``center_is_mean`` is true.
        center_is_mean: If false, the last entry along axis 1 is the center.
            If true, the center is the mean over axis 1.
        center_joints: Optional joint indices restricting which joints the
            mean is taken over. Only consulted when ``center_is_mean`` is
            true. It is honoured for NumPy and TensorFlow inputs alike, with
            or without a mask.

    Returns:
        ``coords3d`` with the center subtracted. In the NumPy path with a
        mask, the masked-out joints are NaN in the result; the caller's array
        is not modified because the NaN marking happens on a copy.
    """
    if not center_is_mean:
        return coords3d - coords3d[:, -1:]

    if isinstance(coords3d, np.ndarray):
        if joint_validity_mask is not None:
            # Work on a copy and mark invalid joints as NaN so that nanmean
            # skips them.
            coords3d = coords3d.copy()
            coords3d[~joint_validity_mask] = np.nan

        if center_joints is None:
            candidates = coords3d
        else:
            candidates = np.take(coords3d, center_joints, axis=1)

        if joint_validity_mask is None:
            center = np.mean(candidates, axis=1, keepdims=True)
        else:
            center = np.nanmean(candidates, axis=1, keepdims=True)
        return coords3d - center

    # TensorFlow path.
    if center_joints is None:
        candidates = coords3d
        candidate_mask = joint_validity_mask
    else:
        candidates = tf.gather(coords3d, center_joints, axis=1)
        if joint_validity_mask is None:
            candidate_mask = None
        else:
            candidate_mask = tf.gather(
                joint_validity_mask, center_joints, axis=1)

    if candidate_mask is None:
        center = tf.reduce_mean(candidates, axis=1, keepdims=True)
    else:
        center = tfu.reduce_mean_masked(candidates,
                                        candidate_mask,
                                        axis=1,
                                        keepdims=True)
    return coords3d - center
示例#2
0
    def compute_losses(self, inps, preds):
        """Compute the 3D-batch loss, the 2D-batch loss and their weighted sum.

        Side effect: ``preds.coords2d_pred_2d`` is replaced by the output of
        ``models.util.align_2d_skeletons``.

        Returns:
            An ``AttrDict`` with ``loss3d``, ``loss2d`` and ``loss``, where
            ``loss = loss3d + FLAGS.loss2d_factor * loss2d``.
        """
        # --- 3D batch: masked mean absolute error of center-relative poses.
        validity3d = inps.joint_validity_mask
        true_centered = tfu3d.center_relative_pose(
            inps.coords3d_true, validity3d, FLAGS.mean_relative)
        pred_centered = tfu3d.center_relative_pose(
            preds.coords3d_rel_pred, validity3d, FLAGS.mean_relative)
        # Presumably a millimetre-to-metre conversion -- TODO confirm units.
        error3d = tf.abs((true_centered - pred_centered) / 1000)

        losses = AttrDict()
        losses.loss3d = tfu.reduce_mean_masked(error3d, validity3d)

        # --- 2D batch: masked mean absolute error after skeleton alignment.
        scale_2d = 1 / FLAGS.proc_side * FLAGS.box_size_mm / 1000
        validity2d = inps.joint_validity_mask_2d
        preds.coords2d_pred_2d = models.util.align_2d_skeletons(
            preds.coords2d_pred_2d, inps.coords2d_true_2d, validity2d)
        error2d = tf.abs(
            (inps.coords2d_true_2d - preds.coords2d_pred_2d) * scale_2d)
        losses.loss2d = tfu.reduce_mean_masked(error2d, validity2d)

        losses.loss = losses.loss3d + FLAGS.loss2d_factor * losses.loss2d
        return losses
示例#3
0
def reconstruct_absolute(coords2d, coords3d_rel, inv_intrinsics):
    """Combine 2D predictions and relative 3D predictions into absolute 3D.

    A reference translation is estimated by ``reconstruct_ref_weak`` or
    ``reconstruct_ref_strong`` (chosen by ``FLAGS.weak_perspective``). Joints
    whose 2D prediction lies inside the
    ``[stride_train, proc_side - stride_train]`` window on both image axes
    take the back-projected 2D-based result. All other joints take the
    relative 3D prediction shifted by the reference.
    """
    pixel_xy = coords2d[:, :, :2]
    coords2d_normalized = model.util.matmul_joint_coords(
        inv_intrinsics, model.util.to_homogeneous(pixel_xy))

    if FLAGS.weak_perspective:
        reconstruct_ref = reconstruct_ref_weak
    else:
        reconstruct_ref = reconstruct_ref_strong

    # A joint counts as "in the field of view" if both its image coordinates
    # are at least one training stride away from the image border.
    lower_bound = tf.cast(FLAGS.stride_train, tf.float32)
    upper_bound = tf.cast(FLAGS.proc_side - FLAGS.stride_train, tf.float32)
    in_fov = tf.reduce_all(
        tf.logical_and(coords2d[:, :, :2] >= lower_bound,
                       coords2d[:, :, :2] <= upper_bound),
        axis=-1)

    ref = reconstruct_ref(coords2d_normalized[:, :, :2], coords3d_rel, in_fov)
    coords_abs_3d_based = coords3d_rel + tf.expand_dims(ref, 1)

    if FLAGS.metrabs_plus:
        # Relative depths come from the third channel of the 2D head; the
        # reference depth is corrected by the mean gap to the 3D head's depth.
        depth_gap = coords3d_rel[:, :, 2] - coords2d[:, :, 2]
        reference_depth = ref[:, 2] + tfu.reduce_mean_masked(
            depth_gap, in_fov, axis=1)
        relative_depths = coords2d[:, :, 2]
    else:
        reference_depth = ref[:, 2]
        relative_depths = coords3d_rel[:, :, 2]

    coords_abs_2d_based = model.util.back_project(
        coords2d_normalized, relative_depths, reference_depth)

    # Repeat the per-joint flag three times along a new last axis so it can
    # select between the two candidate results per coordinate.
    in_fov_per_coord = tf.tile(tf.expand_dims(in_fov, -1), [1, 1, 3])
    return tf.where(in_fov_per_coord, coords_abs_2d_based,
                    coords_abs_3d_based)
示例#4
0
def build_metro_model(joint_info, t):
    """Build the training graph and losses of the "metro" model on ``t``.

    Outside training, this delegates to ``build_metro_inference_model``.
    During training it stores predictions (``coords3d_pred``,
    ``coords3d_pred2d``, ``coords32d_pred2d``), the center-relative poses and
    the losses ``loss3d``, ``loss2d`` and ``loss`` as attributes of ``t``.
    """
    if not tfu.is_training():
        return build_metro_inference_model(joint_info, t)

    if FLAGS.batchnorm_together_2d3d:
        # Run both batches through the network in one pass, then split the
        # output back into its 3D-batch and 2D-batch parts.
        n_3d = tfu.dynamic_batch_size(t.x)
        n_2d = tfu.dynamic_batch_size(t.x_2d)
        merged_input = tf.concat([t.x, t.x_2d], axis=0)
        merged_pred = predict_metro(merged_input, joint_info)
        t.coords3d_pred, t.coords3d_pred2d = tf.split(
            merged_pred, [n_3d, n_2d])
    else:
        # Separate passes, so each batch is normalized only within itself.
        t.coords3d_pred = predict_metro(t.x, joint_info)
        t.coords3d_pred2d = predict_metro(t.x_2d, joint_info)

    # ---- 3D batch loss ----
    t.coords3d_true_rootrel = tfu3d.center_relative_pose(
        t.coords3d_true, t.joint_validity_mask, FLAGS.mean_relative)
    t.coords3d_pred_rootrel = tfu3d.center_relative_pose(
        t.coords3d_pred, t.joint_validity_mask, FLAGS.mean_relative)
    error3d = tf.abs(t.coords3d_true_rootrel - t.coords3d_pred_rootrel)
    t.loss3d = tfu.reduce_mean_masked(error3d, t.joint_validity_mask) / 1000

    # ---- 2D batch loss ----
    # Keep only the joints that the 2D dataset labels, and only the first two
    # coordinates of each.
    joint_info_2d = data.datasets2d.get_dataset(FLAGS.dataset2d).joint_info
    joint_ids_3d = [joint_info.ids[name] for name in joint_info_2d.names]
    gathered = tf.gather(t.coords3d_pred2d, joint_ids_3d, axis=1)
    t.coords32d_pred2d = gathered[..., :2]

    scale_factor2d = 1 / FLAGS.proc_side * FLAGS.box_size_mm / 1000
    t.coords32d_pred2d = model.util.align_2d_skeletons(
        t.coords2d_true2d, t.coords32d_pred2d, t.joint_validity_mask2d)
    error2d = tf.abs(t.coords2d_true2d - t.coords32d_pred2d)
    t.loss2d = tfu.reduce_mean_masked(
        error2d, t.joint_validity_mask2d) * scale_factor2d

    t.loss = t.loss3d + FLAGS.loss2d_factor * t.loss2d
示例#5
0
def compute_pose3d_metrics(inps, preds):
    """Compute 3D pose evaluation metrics for a batch.

    Args:
        inps: Container providing ``coords3d_true``, ``joint_validity_mask``
            and, when 2D predictions are evaluated, ``coords2d_true``.
        preds: Container providing ``coords3d_rel_pred`` and optionally
            ``coords3d_pred_abs`` and ``coords2d_pred``. Membership is tested
            with ``in``.

    Returns:
        An ``AttrDict`` with ``mean_error``, ``mean_error_procrustes``,
        ``mean_auc``, ``mean_pck``, ``auc_wrists``, ``pck_wrists``,
        ``ncps_auc`` and ``ncps``. ``mean_error_abs`` and ``mean_error_2d``
        are added when the corresponding predictions exist.
    """
    metrics = AttrDict()
    mask = inps.joint_validity_mask
    has_abs_pred = 'coords3d_pred_abs' in preds

    # Prefer the absolute prediction when there is one.
    coords3d_pred = (
        preds.coords3d_pred_abs if has_abs_pred else preds.coords3d_rel_pred)

    # Center-relative error: the difference is re-centered, which removes any
    # constant offset between prediction and ground truth.
    rootrelative_diff = tfu3d.center_relative_pose(
        coords3d_pred - inps.coords3d_true, mask,
        center_is_mean=FLAGS.mean_relative)
    dist = tf.norm(rootrelative_diff, axis=-1)
    metrics.mean_error = tfu.reduce_mean_masked(dist, mask)

    if has_abs_pred:
        metrics.mean_error_abs = tfu.reduce_mean_masked(
            tf.norm(inps.coords3d_true - preds.coords3d_pred_abs, axis=-1),
            mask)

    if 'coords2d_pred' in preds:
        metrics.mean_error_2d = tfu.reduce_mean_masked(
            tf.norm(inps.coords2d_true - preds.coords2d_pred[:, :, :2], axis=-1),
            mask)

    # Error after rigid alignment with scale.
    coords3d_pred_procrustes = tfu3d.rigid_align(
        coords3d_pred, inps.coords3d_true,
        joint_validity_mask=mask, scale_align=True)
    dist_procrustes = tf.norm(coords3d_pred_procrustes - inps.coords3d_true, axis=-1)
    metrics.mean_error_procrustes = tfu.reduce_mean_masked(dist_procrustes, mask)

    j = data.datasets3d.get_dataset(FLAGS.dataset).joint_info.ids
    threshold = np.float32(150)

    auc_score = tfu.auc(dist, 0, threshold)
    metrics.mean_auc = tfu.reduce_mean_masked(auc_score, mask)

    is_correct = tf.cast(dist <= threshold, tf.float32)
    metrics.mean_pck = tfu.reduce_mean_masked(is_correct, mask)

    if 'lwri' in j and 'rwri' in j:
        all_wrists = [idx for name, idx in j.items() if 'lwri' in name or 'rwri' in name]
        wrist_mask = tf.gather(mask, all_wrists, axis=1)
        metrics.pck_wrists = tfu.reduce_mean_masked(
            tf.gather(is_correct, all_wrists, axis=1), wrist_mask)
        metrics.auc_wrists = tfu.reduce_mean_masked(
            tf.gather(auc_score, all_wrists, axis=1), wrist_mask)
    else:
        # Use float32 placeholders so these entries have the same dtype as in
        # the branch above; a bare tf.constant(0) would be int32.
        metrics.auc_wrists = tf.constant(0, tf.float32)
        metrics.pck_wrists = tf.constant(0, tf.float32)

    # Per-pose worst-joint error after alignment; invalid joints count as 0.
    masked_dist_pa = tf.where(mask, dist_procrustes, tf.cast(0, tf.float32))
    max_dist_pa = tf.reduce_max(masked_dist_pa, axis=1)
    metrics.ncps_auc = tf.reduce_mean(tfu.auc(max_dist_pa, 50, 150))
    metrics.ncps = tf.reduce_mean(tf.cast(max_dist_pa <= threshold, tf.float32))
    return metrics
示例#6
0
def build_eval_metrics(t):
    """Attach evaluation metrics to ``t``.

    Reads ``t.coords3d_pred``, ``t.coords3d_true`` and
    ``t.joint_validity_mask``. Writes ``mean_error``,
    ``coords3d_pred_procrustes``, ``mean_error_procrustes``, ``auc_for_nms``,
    ``mean_auc``, ``pck`` and ``mean_pck``.
    """
    validity = t.joint_validity_mask
    threshold = np.float32(150)

    # Per-joint distance of the root-relative difference.
    dist = tf.norm(
        tfu3d.root_relative(t.coords3d_pred - t.coords3d_true), axis=-1)
    t.mean_error = tfu.reduce_mean_masked(dist, validity)

    # Same distance after rigid alignment with scale.
    t.coords3d_pred_procrustes = tfu3d.rigid_align(
        t.coords3d_pred, t.coords3d_true,
        joint_validity_mask=validity, scale_align=True)
    dist_procrustes = tf.norm(
        tfu3d.root_relative(t.coords3d_pred_procrustes - t.coords3d_true),
        axis=-1)
    t.mean_error_procrustes = tfu.reduce_mean_masked(dist_procrustes, validity)

    # Score falls linearly from 1 at zero error to 0 at the threshold.
    auc_score = tf.maximum(np.float32(0), 1 - dist / threshold)
    t.auc_for_nms = tfu.reduce_mean_masked(auc_score, validity, axis=0)
    t.mean_auc = tfu.reduce_mean_masked(auc_score, validity)

    is_correct = tf.cast(dist <= threshold, tf.float32)
    t.pck = tfu.reduce_mean_masked(is_correct, validity, axis=0)
    t.mean_pck = tfu.reduce_mean_masked(is_correct, validity)
示例#7
0
def reconstruct_ref_weak(normalized_2d, coords3d_rel, validity_mask):
    """Estimate a reference translation from the 3D-to-2D spread ratio.

    The reference depth is the masked standard deviation of the first two
    3D coordinates divided by that of the normalized 2D coordinates. The
    result is the homogeneous 2D mean scaled by this depth, minus the masked
    mean of ``coords3d_rel``, with axis 1 squeezed out.
    """
    stats_axes = dict(items_axis=1, dimensions_axis=2)
    min_stdev = 1e-5

    # Only the spread of the 3D points is needed; their mean is discarded.
    _, stdev3d = tfu.mean_stdev_masked(
        coords3d_rel[..., :2], validity_mask, **stats_axes)
    mean2d, stdev2d = tfu.mean_stdev_masked(
        normalized_2d[..., :2], validity_mask, **stats_axes)

    # Clamp both spreads away from zero before taking their ratio.
    stdev2d = tf.maximum(stdev2d, min_stdev)
    stdev3d = tf.maximum(stdev3d, min_stdev)

    old_mean = tfu.reduce_mean_masked(
        coords3d_rel, validity_mask, axis=1, keepdims=True)
    new_mean_z = tf.math.divide_no_nan(stdev3d, stdev2d)
    new_mean = model.util.to_homogeneous(mean2d) * new_mean_z
    return tf.squeeze(new_mean - old_mean, 1)
示例#8
0
    def compute_losses(self, inps, preds):
        """Compute all training losses and their weighted total.

        Side effects on ``preds``: ``coords3d_pred_abs`` is rescaled when
        ``FLAGS.scale_agnostic_loss`` is set, and ``coords32d_pred_2d`` is
        replaced by the output of ``models.util.align_2d_skeletons``.

        Returns:
            An ``AttrDict`` with ``loss3d``, ``loss3d_abs``, ``loss23d``,
            ``loss32d``, ``loss22d`` and ``loss``.
        """
        losses = AttrDict()
        mask3d = inps.joint_validity_mask
        per_pose_axes = dict(items_axis=1, dimensions_axis=2)

        # ---- Relative 3D loss ----
        if FLAGS.scale_agnostic_loss:
            # Center both poses on their masked mean and rescale the
            # prediction to the spread of the ground truth.
            mean_true, scale_true = tfu.mean_stdev_masked(
                inps.coords3d_true, mask3d, **per_pose_axes)
            mean_pred, scale_pred = tfu.mean_stdev_masked(
                preds.coords3d_rel_pred, mask3d, **per_pose_axes)
            normalized_pred = tf.math.divide_no_nan(
                preds.coords3d_rel_pred - mean_pred, scale_pred)
            coords3d_pred_rootrel = normalized_pred * scale_true
            coords3d_true_rootrel = inps.coords3d_true - mean_true
        else:
            coords3d_true_rootrel = tfu3d.center_relative_pose(
                inps.coords3d_true, mask3d, FLAGS.mean_relative)
            coords3d_pred_rootrel = tfu3d.center_relative_pose(
                preds.coords3d_rel_pred, mask3d, FLAGS.mean_relative)

        rootrel_error = tf.abs(
            (coords3d_true_rootrel - coords3d_pred_rootrel) / 1000)
        losses.loss3d = tfu.reduce_mean_masked(rootrel_error, mask3d)

        # ---- Absolute 3D loss ----
        if FLAGS.scale_agnostic_loss:
            # Spread measured against a fixed all-zero reference; the
            # absolute prediction is rescaled in place to match.
            _, abs_scale_true = tfu.mean_stdev_masked(
                inps.coords3d_true, mask3d,
                fixed_ref=tf.zeros_like(inps.coords3d_true),
                **per_pose_axes)
            _, abs_scale_pred = tfu.mean_stdev_masked(
                preds.coords3d_pred_abs, mask3d,
                fixed_ref=tf.zeros_like(inps.coords3d_true),
                **per_pose_axes)
            preds.coords3d_pred_abs = tf.math.divide_no_nan(
                preds.coords3d_pred_abs, abs_scale_pred) * abs_scale_true

        # The absolute loss is zero up to and including step 5000.
        # NOTE(review): this is a Python-level branch on self.global_step --
        # confirm it is evaluated eagerly / per step as intended.
        if self.global_step > 5000:
            abs_error = tf.abs(
                (inps.coords3d_true - preds.coords3d_pred_abs) / 1000)
            losses.loss3d_abs = tfu.reduce_mean_masked(abs_error, mask3d)
        else:
            losses.loss3d_abs = tf.constant(0, tf.float32)

        # ---- 2D losses ----
        scale_2d = 1 / FLAGS.proc_side * FLAGS.box_size_mm / 1000

        # 2D head on the 3D batch.
        error23d = tf.abs((inps.coords2d_true - preds.coords2d_pred) * scale_2d)
        losses.loss23d = tfu.reduce_mean_masked(error23d, mask3d)

        # 3D head on the 2D batch, after alignment to the 2D ground truth.
        preds.coords32d_pred_2d = models.util.align_2d_skeletons(
            preds.coords32d_pred_2d, inps.coords2d_true_2d,
            inps.joint_validity_mask_2d)
        error32d = tf.abs(
            (inps.coords2d_true_2d - preds.coords32d_pred_2d) * scale_2d)
        losses.loss32d = tfu.reduce_mean_masked(
            error32d, inps.joint_validity_mask_2d)

        # 2D head on the 2D batch.
        error22d = tf.abs(
            (inps.coords2d_true_2d - preds.coords22d_pred_2d) * scale_2d)
        losses.loss22d = tfu.reduce_mean_masked(
            error22d, inps.joint_validity_mask_2d)

        # ---- Total ----
        total_3d_batch = tf.add_n([
            losses.loss3d,
            losses.loss23d,
            FLAGS.absloss_factor * losses.loss3d_abs,
        ])
        total_2d_batch = tf.add_n([losses.loss22d, losses.loss32d])
        losses.loss = total_3d_batch + FLAGS.loss2d_factor * total_2d_batch
        return losses
示例#9
0
def build_25d_model(joint_info, t):
    """Build the training graph, losses and post-processing of the 2.5D model.

    Outside training, this delegates to ``build_25d_inference_model``.
    During training it stores predictions, the losses (``loss2d_3d``,
    ``loss_z``, ``loss2d``, ``loss3d``, ``loss``) and the back-projected
    ``coords3d_pred`` as attributes of ``t``.
    """
    if not tfu.is_training():
        return build_25d_inference_model(joint_info, t)

    if FLAGS.batchnorm_together_2d3d:
        # One joint pass over both batches, split afterwards.
        n_3d = tfu.dynamic_batch_size(t.x)
        n_2d = tfu.dynamic_batch_size(t.x_2d)
        merged_input = tf.concat([t.x, t.x_2d], axis=0)
        merged_pred = predict_25d(merged_input, joint_info)
        t.coords25d_pred, t.coords25d_pred_2d = tf.split(
            merged_pred, [n_3d, n_2d])
    else:
        # Separate passes, so each batch is normalized only within itself.
        t.coords25d_pred = predict_25d(t.x, joint_info)
        t.coords25d_pred_2d = predict_25d(t.x_2d, joint_info)

    if FLAGS.dataset == 'mpi_inf_3dhp':
        t.coords25d_pred_2d = model.util.adjust_skeleton_3dhp_to_mpii(
            t.coords25d_pred_2d, joint_info)

    # Keep only the joints that the 2D dataset labels.
    joint_info_2d = data.datasets2d.get_dataset(FLAGS.dataset2d).joint_info
    joint_ids_3d = [joint_info.ids[name] for name in joint_info_2d.names]
    t.coords2d_pred2d = tf.gather(
        t.coords25d_pred_2d[..., :2], joint_ids_3d, axis=1)
    t.coords2d_pred = t.coords25d_pred[..., :2]

    # ---- Losses on the 3D batch ----
    scale2d = 1 / FLAGS.proc_side * FLAGS.box_size_mm / 1000
    error2d_3d = tf.abs(t.coords2d_true - t.coords2d_pred)
    t.loss2d_3d = tf.reduce_mean(error2d_3d) * scale2d

    # Target depth: offset from the last joint, shifted by half the box size.
    z_ref = (t.coords3d_true[..., 2] - t.coords3d_true[:, -1:, 2]
             + 0.5 * FLAGS.box_size_mm)
    error_z = tf.abs(z_ref - t.coords25d_pred[..., 2])
    t.loss_z = tf.reduce_mean(error_z) / 1000

    # ---- Loss on the 2D batch ----
    error2d = tf.abs(t.coords2d_true2d - t.coords2d_pred2d)
    t.loss2d = tfu.reduce_mean_masked(
        error2d, t.joint_validity_mask2d) * scale2d

    # The 2D term counts twice as much as the depth term.
    t.loss3d = (t.loss2d_3d * 2 + t.loss_z) / 3
    t.loss = t.loss3d + FLAGS.loss2d_factor * t.loss2d

    # ---- Post-processing: recover absolute depth from bone lengths ----
    dataset = data.datasets3d.get_dataset(
        FLAGS.bone_length_dataset or FLAGS.dataset)

    delta_z_pred = t.coords25d_pred[..., 2] - t.coords25d_pred[:, -1:, 2]
    target_bone_lengths = (
        dataset.trainval_bones if FLAGS.train_on == 'trainval'
        else dataset.train_bones)

    camcoords2d_homog = model.util.matmul_joint_coords(
        t.inv_intrinsics, model.util.to_homogeneous(t.coords2d_pred))
    z_offset = optimize_z_offset_by_bones(
        camcoords2d_homog, delta_z_pred, target_bone_lengths,
        joint_info.stick_figure_edges)
    t.coords3d_pred = model.util.back_project(
        camcoords2d_homog, delta_z_pred, z_offset)
示例#10
0
def build_metrabs_model(joint_info, t):
    """Build the training graph and losses of the MeTRAbs model on ``t``.

    Outside training, this delegates to ``build_metrabs_inference_model``.
    During training it stores the head predictions, the reconstructed
    absolute pose ``coords3d_pred``, the individual losses (``loss23d``,
    ``loss3d``, ``loss3d_abs``, ``loss22d``, ``loss32d``) and the total
    ``loss`` as attributes of ``t``.

    ``t.loss3d_abs`` already includes the absolute-loss weight, which is
    ``FLAGS.absloss_factor`` once ``global_step`` exceeds 5000 and 0 before.
    It therefore enters the total loss without a second multiplication.
    """
    if not tfu.is_training():
        return build_metrabs_inference_model(joint_info, t)

    # Generate predictions for both the 3D and the 2D batch
    if FLAGS.batchnorm_together_2d3d:
        batch_size3d = tfu.dynamic_batch_size(t.x)
        batch_size2d = tfu.dynamic_batch_size(t.x_2d)

        # Concatenate the 3D and the 2D batch
        x_both_batches = tf.concat([t.x, t.x_2d], axis=0)
        coords2d_pred_both, coords3d_rel_pred_both = predict_heads_metrabs(
            x_both_batches, joint_info)
        # Split the results (3D batch and 2D batch)
        t.coords2d_pred, t.coords2d_pred2d = tf.split(
            coords2d_pred_both, [batch_size3d, batch_size2d])
        t.coords3d_rel_pred, t.coords3d_pred2d = tf.split(
            coords3d_rel_pred_both, [batch_size3d, batch_size2d])
    else:
        # Send the 2D and the 3D batch separately through the network,
        # so each gets normalized only within itself
        t.coords2d_pred, t.coords3d_rel_pred = predict_heads_metrabs(
            t.x, joint_info)
        t.coords2d_pred2d, t.coords3d_pred2d = predict_heads_metrabs(
            t.x_2d, joint_info)

    # Reconstruct absolute pose only on the 3D batch. For the first 50 steps
    # the relative prediction is passed through unchanged instead.
    t.coords3d_pred = tf.cond(
        t.global_step > 50, lambda: reconstruct_absolute(
            t.coords2d_pred, t.coords3d_rel_pred, t.inv_intrinsics),
        lambda: t.coords3d_rel_pred)

    ######
    # LOSSES FOR 3D BATCH
    ######
    #
    # Loss on 2D head for 3D batch
    scale_factor2d = 1 / FLAGS.proc_side * FLAGS.box_size_mm / 1000
    t.loss23d = tfu.reduce_mean_masked(
        tf.abs(t.coords2d_true - t.coords2d_pred),
        t.joint_validity_mask) * scale_factor2d

    # Loss on 3D head (relative) for 3D batch
    t.coords3d_true_rootrel = tfu3d.center_relative_pose(
        t.coords3d_true, t.joint_validity_mask, FLAGS.mean_relative)
    t.coords3d_pred_rootrel = tfu3d.center_relative_pose(
        t.coords3d_rel_pred, t.joint_validity_mask, FLAGS.mean_relative)
    rootrel_absdiff = tf.abs(t.coords3d_true_rootrel - t.coords3d_pred_rootrel)
    t.loss3d = tfu.reduce_mean_masked(rootrel_absdiff,
                                      t.joint_validity_mask) / 1000

    # Loss on absolute reconstruction for 3D batch
    absloss_factor = tf.where(
        t.global_step > 5000,
        tf.convert_to_tensor(FLAGS.absloss_factor, dtype=tf.float32),
        tf.convert_to_tensor(0, dtype=tf.float32))
    absdiff = tf.abs(t.coords3d_true - t.coords3d_pred)
    t.loss3d_abs = absloss_factor * tfu.reduce_mean_masked(
        absdiff, t.joint_validity_mask) / 1000
    # t.loss3d_abs is already weighted by absloss_factor; multiplying again
    # here would square the weight.
    losses3d = [t.loss3d, t.loss23d, t.loss3d_abs]

    ######
    # LOSSES FOR 2D BATCH
    ######
    #
    # Pick out the joints that correspond to the ones labeled in the 2D dataset
    joint_info_2d = data.datasets2d.get_dataset(FLAGS.dataset2d).joint_info
    joint_ids_3d = [joint_info.ids[name] for name in joint_info_2d.names]
    t.coords32d_pred2d = tf.gather(t.coords3d_pred2d, joint_ids_3d,
                                   axis=1)[..., :2]
    t.coords22d_pred2d = tf.gather(t.coords2d_pred2d, joint_ids_3d,
                                   axis=1)[..., :2]

    # Loss on 2D head for 2D batch
    t.loss22d = tfu.reduce_mean_masked(
        tf.abs(t.coords2d_true2d - t.coords22d_pred2d),
        t.joint_validity_mask2d) * scale_factor2d

    # Loss on 3D head for 2D batch
    t.coords32d_pred2d = model.util.align_2d_skeletons(t.coords2d_true2d,
                                                       t.coords32d_pred2d,
                                                       t.joint_validity_mask2d)
    t.loss32d = tfu.reduce_mean_masked(
        tf.abs(t.coords2d_true2d - t.coords32d_pred2d),
        t.joint_validity_mask2d) * scale_factor2d
    losses2d = [t.loss22d, t.loss32d]

    t.loss = tf.add_n(losses3d) + FLAGS.loss2d_factor * tf.add_n(losses2d)
示例#11
0
def pck(pred, labels, joint_validity_mask, reference_size,
        max_relative_distance):
    """Return the masked fraction of correct predictions, reduced over axis 0.

    A prediction counts as correct when its relative distance to the label,
    as computed by ``relative_distance``, is at most
    ``max_relative_distance``.
    """
    within_threshold = (
        relative_distance(pred, labels, reference_size)
        <= max_relative_distance)
    hits = tf.cast(within_threshold, tf.float32)
    return tfu.reduce_mean_masked(hits, joint_validity_mask, axis=0)
示例#12
0
def auc(pred, labels, joint_validity_mask, reference_size,
        max_relative_distance):
    """Return the masked mean AUC-style score, reduced over axis 0.

    The per-item score falls linearly from 1 at zero relative distance to 0
    at ``max_relative_distance`` and stays 0 beyond it:
    ``max(0, 1 - rel_dist / max_relative_distance)``.
    """
    rel_dist = relative_distance(pred, labels, reference_size)
    # Higher error must give a lower score, so clamp the linear falloff at
    # zero rather than taking max(rel_dist, threshold).
    linear_score = 1 - rel_dist / max_relative_distance
    score = tf.maximum(linear_score, tf.zeros_like(linear_score))
    return tfu.reduce_mean_masked(score, joint_validity_mask, axis=0)
示例#13
0
def build_25d_model(joint_info, learning_phase, t, reuse=None):
    if learning_phase != TRAIN:
        return build_inference_model(joint_info, learning_phase, t, reuse=reuse)

    with tf.name_scope(None, 'Prediction'):
        depth = FLAGS.depth

        def im2pred(im, reuse=reuse):
            net_output = model.architectures.resnet(
                im, n_out=depth * joint_info.n_joints, scope='MainPart', reuse=reuse,
                stride=FLAGS.stride_train, centered_stride=FLAGS.centered_stride,
                resnet_name=FLAGS.architecture)

            side = tfu.static_image_shape(net_output)[0]
            net_output_nchw = tfu.std_to_nchw(net_output)
            reshaped = tf.reshape(net_output_nchw, [-1, depth, joint_info.n_joints, side, side])
            logits = tf.transpose(reshaped, [0, 2, 3, 4, 1])
            softmaxed = tfu.softmax(logits, axis=[2, 3, 4])
            coords3d = tf.stack(tfu.decode_heatmap(softmaxed, [3, 2, 4]), axis=-1)
            return logits, softmaxed, coords3d

        # PREDICT FOR 3D BATCH
        logits, t.softmaxed, coords3d = im2pred(t.x)
        t.coords3d_pred = heatmap_to_25d(coords3d, learning_phase)

        # PREDICT FOR 2D BATCH
        if FLAGS.train_mixed:
            t.coords3d_pred_2d = heatmap_to_25d(im2pred(t.x_2d)[-1], learning_phase)
            if FLAGS.dataset == 'mpi_inf_3dhp':
                t.coords3d_pred_2d = adjust_skeleton_3dhp_to_mpii(t.coords3d_pred_2d, joint_info)

            joint_info_2d = data.datasets2d.get_dataset(FLAGS.dataset2d).joint_info
            joint_ids_3d = [joint_info.ids[name] for name in joint_info_2d.names]
            t.coords2d_pred_2d = tf.gather(t.coords3d_pred_2d, joint_ids_3d, axis=1)[..., :2]

        # LOSS 3D BATCH
        scale = 1 / FLAGS.proc_side * FLAGS.box_size_mm / 1000
        t.loss2d_3d = tf.reduce_mean(tf.abs(t.coords2d_true - t.coords3d_pred[..., :2])) * scale
        z_ref = t.coords3d_true[..., 2] - t.coords3d_true[:, -1:, 2] + 0.5 * FLAGS.box_size_mm
        t.loss_z = tf.reduce_mean(tf.abs(z_ref - t.coords3d_pred[..., 2])) / 1000

        # LOSS 2D BATCH
        if FLAGS.train_mixed:
            t.loss2d = tfu.reduce_mean_masked(
                tf.abs(t.coords2d_true_2d - t.coords2d_pred_2d), t.joint_validity_mask_2d) * scale
        else:
            t.loss2d = 0

        t.loss3d = (t.loss2d_3d * 2 + t.loss_z) / 3
        t.loss = FLAGS.loss2d_factor * t.loss2d + t.loss3d

        # POST-PROCESSING
        if FLAGS.bone_length_dataset:
            dataset = data.datasets.get_dataset(FLAGS.bone_length_dataset)
        else:
            dataset = data.datasets.current_dataset()

        im_pred2d = t.coords3d_pred[..., :2]
        im_pred2d_homog = to_homogeneous_coords(im_pred2d)
        camcoords2d_homog = matmul_joint_coords(t.inv_intrinsics, im_pred2d_homog)
        delta_z_pred = t.coords3d_pred[..., 2] - t.coords3d_pred[:, -1:, 2]

        if FLAGS.train_on == 'trainval':
            target_bone_lengths = dataset.trainval_bones
        else:
            target_bone_lengths = dataset.train_bones

        z_offset = optimize_z_offset_by_bones(
            camcoords2d_homog, delta_z_pred, target_bone_lengths, joint_info.stick_figure_edges)
        t.coords3d_pred = back_project(camcoords2d_homog, delta_z_pred, z_offset)

    t.coords3d_pred_orig_cam = to_orig_cam(t.coords3d_pred, t.rot_to_orig_cam, joint_info)
    t.coords3d_true_orig_cam = to_orig_cam(t.coords3d_true, t.rot_to_orig_cam, joint_info)
示例#14
0
def build_metro_model(joint_info, learning_phase, t, reuse=None):
    """Build the training graph and losses of the "metro" model.

    When ``learning_phase`` is not ``TRAIN``, this delegates to
    ``build_inference_model`` and returns its result. Otherwise the results
    are stored as attributes of ``t``: ``logits``, ``softmaxed``,
    ``coords3d_pred``, ``coords3d_true_rootrel``, ``coords3d_pred_rootrel``,
    ``loss3d``, ``loss2d``, ``loss``, ``coords3d_pred_orig_cam`` and
    ``coords3d_true_orig_cam``. With ``FLAGS.train_mixed`` it also stores
    ``coords3d_pred_2d``, ``coords2d_pred_2d`` and
    ``coords2d_true_2d_scaled``.
    """
    if learning_phase != TRAIN:
        return build_inference_model(joint_info, learning_phase, t, reuse=reuse)

    with tf.name_scope(None, 'Prediction'):
        depth = FLAGS.depth

        def im2pred(im, reuse=reuse):
            """Run the backbone on ``im`` and decode volumetric heatmaps.

            Returns the logits, their softmax over axes 2-4, and the
            coordinates decoded from the softmaxed heatmap.
            """
            net_output = model.architectures.resnet(
                im, n_out=depth * joint_info.n_joints, scope='MainPart', reuse=reuse,
                stride=FLAGS.stride_train, centered_stride=FLAGS.centered_stride,
                resnet_name=FLAGS.architecture)

            net_output_nchw = tfu.std_to_nchw(net_output)
            side = tfu.static_image_shape(net_output)[0]
            # Split the channel axis into (depth, joints), then move depth last:
            # [batch, depth, joints, side, side] -> [batch, joints, side, side, depth]
            reshaped = tf.reshape(net_output_nchw, [-1, depth, joint_info.n_joints, side, side])
            logits = tf.transpose(reshaped, [0, 2, 3, 4, 1])
            # Normalize jointly over the two spatial axes and the depth axis.
            softmaxed = tfu.softmax(logits, axis=[2, 3, 4])
            # Decode along axes 3, 2, 4 in that order and stack as the last axis.
            # Presumably this yields (x, y, z) ordering -- TODO confirm.
            coords3d = tf.stack(tfu.decode_heatmap(softmaxed, [3, 2, 4]), axis=-1)
            return logits, softmaxed, coords3d

        # PREDICT FOR 3D BATCH
        t.logits, t.softmaxed, coords3d = im2pred(t.x)
        t.coords3d_pred = heatmap_to_metric(coords3d, learning_phase)
        # PREDICT FOR 2D BATCH (second backbone call; only the decoded
        # coordinates are kept)
        if FLAGS.train_mixed:
            coords3d_2d = im2pred(t.x_2d)[-1]
            t.coords3d_pred_2d = heatmap_to_metric(coords3d_2d, learning_phase)

        # ADAPT 2D-BATCH PREDICTIONS TO THE 2D DATASET'S JOINT SET
        if FLAGS.train_mixed:
            if FLAGS.dataset == 'mpi_inf_3dhp':
                t.coords3d_pred_2d = adjust_skeleton_3dhp_to_mpii(t.coords3d_pred_2d, joint_info)

            # Keep only the joints labeled in the 2D dataset, and only the
            # first two coordinates of each.
            joint_info_2d = data.datasets2d.get_dataset(FLAGS.dataset2d).joint_info
            joint_ids_3d = [joint_info.ids[name] for name in joint_info_2d.names]
            t.coords2d_pred_2d = tf.gather(t.coords3d_pred_2d, joint_ids_3d, axis=1)[..., :2]

        # LOSS 3D BATCH
        t.coords3d_true_rootrel = tfu3d.root_relative(t.coords3d_true)
        t.coords3d_pred_rootrel = tfu3d.root_relative(t.coords3d_pred)

        absdiff = tf.abs(t.coords3d_true_rootrel - t.coords3d_pred_rootrel)
        t.loss3d = tfu.reduce_mean_masked(absdiff, t.joint_validity_mask) / 1000

        # LOSS 2D BATCH
        if FLAGS.train_mixed:
            # align_2d_skeletons returns two values here; both the adjusted
            # ground truth and the adjusted prediction enter the loss.
            t.coords2d_true_2d_scaled, t.coords2d_pred_2d = align_2d_skeletons(
                t.coords2d_true_2d, t.coords2d_pred_2d, t.joint_validity_mask_2d)
            scale = 1 / FLAGS.proc_side * FLAGS.box_size_mm / 1000

            t.loss2d = tfu.reduce_mean_masked(
                tf.abs(t.coords2d_true_2d_scaled - t.coords2d_pred_2d),
                t.joint_validity_mask_2d) * scale
        else:
            t.loss2d = 0

        loss2d_factor = FLAGS.loss2d_factor
        t.loss = t.loss3d + loss2d_factor * t.loss2d

    t.coords3d_pred_orig_cam = to_orig_cam(t.coords3d_pred, t.rot_to_orig_cam, joint_info)
    t.coords3d_true_orig_cam = to_orig_cam(t.coords3d_true, t.rot_to_orig_cam, joint_info)