def predict_metro(im, joint_info):
    stride = FLAGS.stride_train if tfu.is_training() else FLAGS.stride_test

    # 1. Feed image through backbone
    logits = model.architectures.resnet(
        im, n_outs=[FLAGS.depth * joint_info.n_joints], scope='MainPart',
        reuse=tf.compat.v1.AUTO_REUSE, stride=stride, centered_stride=FLAGS.centered_stride,
        resnet_name=FLAGS.architecture)[0]
    logits = tfu.std_to_nchw(logits)
    side = tfu.static_shape(logits)[2]

    # 2. Reshape the 3D heatmap logits to actually be 3D: [batch, joints, H, W, D]
    logits = tf.reshape(logits, [-1, FLAGS.depth, joint_info.n_joints, side, side])
    logits = tf.transpose(logits, [0, 2, 3, 4, 1])

    # 3. Decode the heatmap coordinates using soft-argmax, resulting in values between 0 and 1
    coords3d_raw = tfu.soft_argmax(logits, [3, 2, 4])

    # 4. Scale and shift the normalized heatmap coordinates to get metric and pixel values
    coords3d_pred = model.util.heatmap_to_metric(coords3d_raw)
    return coords3d_pred
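# The following is a minimal, self-contained sketch of the soft-argmax decoding used in step 3,
# assuming the standard formulation: a softmax over all heatmap voxels of each joint, followed by
# the expectation of a normalized coordinate grid, which yields values in [0, 1]. It is only an
# illustration of the technique, not the project's tfu.soft_argmax, and soft_argmax_3d_sketch is
# a hypothetical name.
import tensorflow as tf


def soft_argmax_3d_sketch(logits):
    """logits: [batch, joints, height, width, depth] -> coords in [0, 1], shape [batch, joints, 3]."""
    shape = tf.shape(logits)
    # Joint softmax over all heatmap voxels of each joint
    probs = tf.nn.softmax(tf.reshape(logits, [shape[0], shape[1], -1]), axis=-1)
    probs = tf.reshape(probs, shape)

    def expected_coord(axis):
        size = shape[axis]
        # Normalized coordinate grid (voxel centers) along `axis`, broadcast over the other axes
        grid = (tf.cast(tf.range(size), tf.float32) + 0.5) / tf.cast(size, tf.float32)
        broadcast_shape = [1, 1] + [-1 if a == axis else 1 for a in range(2, 5)]
        return tf.reduce_sum(probs * tf.reshape(grid, broadcast_shape), axis=[2, 3, 4])

    # Order (x, y, z): x from the width axis (3), y from the height axis (2), z from the depth
    # axis (4), mirroring the axis order [3, 2, 4] passed to tfu.soft_argmax above.
    return tf.stack([expected_coord(3), expected_coord(2), expected_coord(4)], axis=-1)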
def optimize_z_offset_by_bones(xs, delta_zs, bone_lengths_ideal, edges):
    """Runs the per-example NumPy z-offset optimization (fitting the ideal bone lengths)
    inside the TF graph via py_func."""

    def fun(xs_, delta_zs_):
        return np.array([
            optimize_z_offset_by_bones_single(x, delta_z, bone_lengths_ideal, edges)
            for x, delta_z in zip(xs_, delta_zs_)], dtype=np.float32)

    batch_size = tfu.static_shape(xs)[0]
    return tfu.py_func_with_shapes(
        fun, [xs, delta_zs], output_types=(np.float32,),
        output_shapes=([batch_size],))[0]
def spatial_slice(inp, sl1, sl2=None):
    """Slices `inp` with `sl1` and `sl2` along the image (spatial) axes, wherever those axes are
    in the current data layout (e.g. inp[:, sl1, sl2, :] in NHWC)."""
    if sl2 is None:
        sl2 = sl1

    image_axes = tfu.image_axes()
    ndims = len(tfu.static_shape(inp))
    indices = [slice(None)] * ndims
    indices[image_axes[0]] = sl1
    indices[image_axes[1]] = sl2
    return inp[indices]
def transform_coords(coords, n_output_joints):
    """Learned linear combination of joint positions, for either 2D or 3D coordinates."""

    def normalize_weights(w):
        return w / tf.reduce_sum(w, axis=0, keepdims=True)

    n_input_joints = tfu.static_shape(coords)[1]
    initializer = tf.constant_initializer(np.eye(n_input_joints), dtype=coords.dtype)
    weights = tf.get_variable(
        name='skeleton_weights', shape=(n_input_joints, n_output_joints), dtype=coords.dtype,
        initializer=initializer, trainable=True, regularizer=None,
        constraint=normalize_weights)
    return tf.einsum('bjc,jJ->bJc', coords, normalize_weights(weights))
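# Hypothetical toy example of the einsum mixing above: each output joint is a weighted combination
# of the input joints, with each weight column normalized to sum to 1. With the np.eye initializer
# above, the layer starts out as an identity mapping whenever n_output_joints equals n_input_joints.
import numpy as np

coords_example = np.arange(2 * 3 * 3, dtype=np.float32).reshape(2, 3, 3)  # [batch=2, joints=3, xyz]
weights_example = np.array([[1.0, 0.5],
                            [0.0, 0.5],
                            [0.0, 0.0]], dtype=np.float32)                # [in_joints=3, out_joints=2]
weights_example /= weights_example.sum(axis=0, keepdims=True)
mixed = np.einsum('bjc,jJ->bJc', coords_example, weights_example)         # [batch=2, out_joints=2, xyz]
# Output joint 0 equals input joint 0; output joint 1 is the midpoint of input joints 0 and 1.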
def rigid_align(coords_pred, coords_true, *, joint_validity_mask=None, scale_align=False):
    """Rigidly aligns each predicted pose to the corresponding ground-truth pose
    (optionally also matching scale), delegating to the NumPy implementation via py_func."""

    def func(_coords_pred, _coords_true, _joint_validity_mask):
        return util3d.rigid_align_many(
            _coords_pred, _coords_true, joint_validity_mask=_joint_validity_mask,
            scale_align=scale_align)

    if joint_validity_mask is None:
        joint_validity_mask = tf.ones_like(coords_pred[..., 0], dtype=tf.bool)

    return tfu.py_func_with_shapes(
        func=func, inp=[coords_pred, coords_true, joint_validity_mask],
        output_types=tf.float32, output_shapes=tfu.static_shape(coords_pred))
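# For reference, a minimal NumPy sketch of single-pose rigid (Procrustes/Kabsch) alignment of the
# kind util3d.rigid_align_many presumably performs per example. The function name and details here
# are illustrative assumptions, not the project's implementation, and the validity mask is omitted.
import numpy as np


def rigid_align_single_sketch(pred, true, scale_align=False):
    """pred, true: [n_joints, 3]. Returns pred rotated/translated (optionally scaled) onto true."""
    mean_pred, mean_true = pred.mean(axis=0), true.mean(axis=0)
    p, t = pred - mean_pred, true - mean_true
    u, s, vt = np.linalg.svd(p.T @ t)
    # Keep it a proper rotation (det = +1) by flipping the last singular direction if needed
    d = np.sign(np.linalg.det(u @ vt))
    rot = u @ np.diag([1.0, 1.0, d]) @ vt
    scale = (s * [1.0, 1.0, d]).sum() / (p ** 2).sum() if scale_align else 1.0
    return scale * (p @ rot) + mean_true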
def reconstruct_ref_strong(normalized_2d, coords3d_rel, validity_mask):
    """Reconstructs the reference point location.

    Args:
        normalized_2d: normalized image coordinates of the joints
            (without intrinsics applied), shape [batch_size, n_points, 2]
        coords3d_rel: 3D camera coordinate offsets relative to the unknown reference
            point which we want to reconstruct, shape [batch_size, n_points, 3]
        validity_mask: boolean mask of shape [batch_size, n_points] containing True
            where the point is reliable and should be used in the reconstruction

    Returns:
        The 3D reference point in camera coordinates, shape [batch_size, 3]
    """

    def root_mean_square(x):
        return tf.sqrt(tf.reduce_mean(tf.square(x)))

    n_batch = tfu.dynamic_batch_size(normalized_2d)
    n_points = tfu.static_shape(normalized_2d)[1]
    eyes = tf.tile(tf.expand_dims(tf.eye(2, 2), 0), [n_batch, n_points, 1])

    reshaped2d = tf.reshape(normalized_2d, [-1, n_points * 2, 1])
    scale2d = root_mean_square(reshaped2d)
    A = tf.concat([eyes, -reshaped2d / scale2d], axis=2)

    rel_backproj = normalized_2d * coords3d_rel[:, :, 2:] - coords3d_rel[:, :, :2]
    scale_rel_backproj = root_mean_square(rel_backproj)
    b = tf.reshape(rel_backproj / scale_rel_backproj, [-1, n_points * 2, 1])

    weights = tf.cast(validity_mask, tf.float32) + np.float32(1e-4)
    weights = tf.reshape(tf.tile(tf.expand_dims(weights, -1), [1, 1, 2]), [-1, n_points * 2, 1])

    ref = tf.linalg.lstsq(A * weights, b * weights, fast=True)
    ref = tf.concat([ref[:, :2], ref[:, 2:] / scale2d], axis=1) * scale_rel_backproj
    return ref
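# Sketch of the linear system underlying reconstruct_ref_strong, without the numerical re-scaling
# and validity weighting used above. With a normalized image point (x, y) and root-relative offset
# (dx, dy, dz), perspective projection gives x = (X + dx) / (Z + dz) and y = (Y + dy) / (Z + dz)
# for the unknown reference point (X, Y, Z), i.e. the linear equations
#     X - x * Z = x * dz - dx
#     Y - y * Z = y * dz - dy
# Stacking two such rows per point yields the A @ [X, Y, Z]^T = b system solved above.
# reconstruct_ref_sketch is a hypothetical, single-example illustration of that system in NumPy.
import numpy as np


def reconstruct_ref_sketch(normalized_2d, coords3d_rel):
    """normalized_2d: [n_points, 2], coords3d_rel: [n_points, 3] -> reference point, shape [3]."""
    n = normalized_2d.shape[0]
    # Rows alternate [1, 0, -x_i] and [0, 1, -y_i]
    A = np.concatenate([np.tile(np.eye(2), (n, 1)), -normalized_2d.reshape(-1, 1)], axis=1)
    b = (normalized_2d * coords3d_rel[:, 2:] - coords3d_rel[:, :2]).reshape(-1)
    return np.linalg.lstsq(A, b, rcond=None)[0]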