def _override_keypoint_masks(keypoint_masks, keypoint_profile, part_names,
                             overriding_func):
    """Overrides the keypoint mask entries that belong to the named parts.

    The keypoint indices of every named part are looked up through
    `keypoint_profile.get_standard_part_index`, de-duplicated, and then the
    entries at those indices along the last axis of `keypoint_masks` are
    updated via `data_utils.update_sub_tensor`.

    Args:
        keypoint_masks: A tensor for input keypoint masks.
        keypoint_profile: A KeypointProfile object for keypoints.
        part_names: A list of standard names of parts of which the masks are
            overridden. See `KeypointProfile.get_standard_part_index` for
            standard part names.
        overriding_func: A function that returns overriding tensors. It is
            passed to `data_utils.update_sub_tensor` as `update_func`.

    Returns:
        A tensor for output keypoint masks.
    """
    collected_indices = [
        index
        for part_name in part_names
        for index in keypoint_profile.get_standard_part_index(part_name)
    ]
    # Several part names may share keypoints; keep each index only once.
    unique_indices = list(set(collected_indices))
    return data_utils.update_sub_tensor(
        keypoint_masks,
        indices=unique_indices,
        axis=-1,
        update_func=overriding_func)
# Example #2
# 0
    def test_update_sub_tensor(self):
        """Checks that only the selected slices are changed by the update.

        Slices 0, 2 and 4 along axis -2 receive a per-element offset; slices 1
        and 3 are expected to come back unchanged.
        """
        # Shape = [3, 5, 2].
        inputs = tf.constant([
            [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]],
            [[10.0, 11.0], [12.0, 13.0], [14.0, 15.0], [16.0, 17.0],
             [18.0, 19.0]],
            [[20.0, 21.0], [22.0, 23.0], [24.0, 25.0], [26.0, 27.0],
             [28.0, 29.0]],
        ])

        # Offsets for the three selected slices. Shape = [3, 3, 2].
        offset_values = [
            [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
            [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]],
            [[1.3, 1.4], [1.5, 1.6], [1.7, 1.8]],
        ]

        def add_offsets(selected):
            return selected + tf.constant(offset_values)

        outputs = data_utils.update_sub_tensor(
            inputs, indices=[0, 2, 4], axis=-2, update_func=add_offsets)

        expected = [
            [[0.1, 1.2], [2.0, 3.0], [4.3, 5.4], [6.0, 7.0], [8.5, 9.6]],
            [[10.7, 11.8], [12.0, 13.0], [14.9, 16.0], [16.0, 17.0],
             [19.1, 20.2]],
            [[21.3, 22.4], [22.0, 23.0], [25.5, 26.6], [26.0, 27.0],
             [29.7, 30.8]],
        ]
        self.assertAllClose(outputs, expected)
def override_points(points, from_indices_list, to_indices):
    """Overrides points with the center of other points.

    Every point at `to_indices` is replaced by one shared point: the mean of
    the per-group points obtained by calling `get_points` on each index group
    in `from_indices_list`.

    For example:

        from_indices_list = [[0, 1], [2]]
        to_indices = [3, 4]
        updated_points = override_points(
            points, from_indices_list, to_indices)

    Will result in:
        updated_points[..., 3, :] ==
          ((points[..., 0, :] + points[..., 1, :]) / 2 + points[..., 2, :]) / 2
        updated_points[..., 4, :] ==
          ((points[..., 0, :] + points[..., 1, :]) / 2 + points[..., 2, :]) / 2

    Args:
        points: A tensor for points to override. Shape = [..., num_points,
            point_dim].
        from_indices_list: A list of integer lists for point indices to
            compute overriding points.
        to_indices: A list of integers for point indices to be overridden.

    Returns:
        A tensor for updated points.
    """
    # One entry per index group, stacked along the point axis.
    group_points = tf.concat(
        [get_points(points, group) for group in from_indices_list], axis=-2)
    # Collapse the groups into a single point; the point axis is kept so it
    # can be tiled below.
    replacement = tf.math.reduce_mean(group_points, axis=-2, keepdims=True)
    # Repeat the replacement once per target index.
    replacement = data_utils.tile_last_dims(
        replacement, last_dim_multiples=[len(to_indices), 1])

    def _use_replacement(_):
        # The current values at `to_indices` are ignored entirely.
        return replacement

    return data_utils.update_sub_tensor(
        points, indices=to_indices, axis=-2, update_func=_use_replacement)
def create_model_input(keypoints_2d,
                       keypoint_masks_2d,
                       keypoints_3d,
                       model_input_keypoint_type,
                       model_input_keypoint_mask_type=(
                           common.MODEL_INPUT_KEYPOINT_MASK_TYPE_NO_USE),
                       normalize_keypoints_2d=True,
                       keypoint_profile_2d=None,
                       uniform_keypoint_jittering_max_offset_2d=0.0,
                       gaussian_keypoint_jittering_offset_stddev_2d=0.0,
                       keypoint_dropout_probs=(0.0, 0.0),
                       structured_keypoint_mask_processor=None,
                       set_on_mask_for_non_anchors=False,
                       mix_mask_sub_batches=False,
                       rescale_features=False,
                       forced_mask_on_part_names=None,
                       forced_mask_off_part_names=None,
                       keypoint_profile_3d=None,
                       azimuth_range=(-math.pi, math.pi),
                       elevation_range=(-math.pi / 6.0, math.pi / 6.0),
                       roll_range=(-math.pi / 6.0, math.pi / 6.0),
                       normalized_camera_depth_range=(),
                       sequential_inputs=False,
                       seed=None):
    """Creates model input features from input data.

    The processing steps run in this order: 2D keypoint preprocessing, random
    keypoint dropout on the masks, structured mask processing, 2D
    normalization, keypoint jittering, mask overrides (non-anchor, sub-batch
    mixing, forced on/off parts), optional masking of the keypoints, optional
    concatenation of masks to the features, optional feature rescaling, and
    finally flattening of the last two dimensions.

    Args:
        keypoints_2d: A tensor for input 2D keypoints. Shape = [...,
            num_keypoints, 2]. Use None if irrelevant.
        keypoint_masks_2d: A tensor for input 2D keypoint masks. Shape = [...,
            num_keypoints]. Use None if irrelevant.
        keypoints_3d: A tensor for input 3D keypoints. Shape = [...,
            num_keypoints, 3]. Use None if irrelevant.
        model_input_keypoint_type: An enum string for model input type. See
            `MODEL_INPUT_TYPE_*` for supported values.
        model_input_keypoint_mask_type: An enum string for model input
            keypoint mask type. See `MODEL_INPUT_KEYPOINT_MASK_TYPE_*` for
            supported values.
        normalize_keypoints_2d: A boolean for whether to normalize 2D
            keypoints at the end.
        keypoint_profile_2d: A KeypointProfile2D object for input 2D
            keypoints. Required for normalizing 2D keypoints, 3D-to-2D
            projection, or forcing masks on/off.
        uniform_keypoint_jittering_max_offset_2d: A float for maximum 2D
            keypoint jittering offset. Random jittering offset within
            [-uniform_keypoint_jittering_max_offset_2d,
            uniform_keypoint_jittering_max_offset_2d] is to be added to each
            keypoint 2D. Note that the jittering happens after the 2D
            normalization. Ignored if non-positive.
        gaussian_keypoint_jittering_offset_stddev_2d: A float for standard
            deviation of Gaussian 2D keypoint jittering offset. Random
            jittering offset sampled from
            N(0, gaussian_keypoint_jittering_offset_stddev_2d) is to be added
            to each keypoint. Note that the jittering happens after the 2D
            normalization. Ignored if non-positive.
        keypoint_dropout_probs: A tuple of floats for the keypoint random
            dropout probabilities in the format (probability_to_apply,
            probability_to_drop). We perform stratified dropout as first
            select instances with `probability_to_apply` and then drop their
            keypoints with `probability_to_drop`. When `sequential_inputs` is
            True, there might be a third element indicating the probability of
            using sequence-level dropout. Only used when keypoint scores are
            relevant.
        structured_keypoint_mask_processor: A Python function for generating
            keypoint masks with structured dropout. Ignored if None.
        set_on_mask_for_non_anchors: A boolean for whether to always use on
            (1) masks for non-anchor samples. We assume the second from the
            left tensor dimension is for anchor/non-anchor, and the non-anchor
            samples start at the second element along that dimension.
        mix_mask_sub_batches: A boolean for whether to apply sub-batch mixing
            to processed masks and all-one masks.
        rescale_features: A boolean for whether to rescale features by the
            ratio between total number of mask elements and kept mask
            elements.
        forced_mask_on_part_names: A list of standard names of parts of which
            the masks are forced on (by setting value to 1.0). See
            `KeypointProfile.get_standard_part_index` for standard part names.
        forced_mask_off_part_names: A list of standard names of parts of which
            the masks are forced off (by setting value to 0.0). See
            `KeypointProfile.get_standard_part_index` for standard part names.
        keypoint_profile_3d: A KeypointProfile3D object for input 3D
            keypoints. Only used when 3D-to-2D projection is involved.
        azimuth_range: A tuple for minimum and maximum azimuth angles to
            randomly rotate 3D keypoints with. For non-sequential inputs, a
            2-tuple for (minimum angle, maximum angle) is expected. For
            sequence inputs, uses 2-tuple to independently sample starting and
            ending camera angles, or uses 4-tuple for (minimum starting angle,
            maximum starting angle, minimum angle increment, maximum angle
            increment) to first sample starting angles and add random delta
            angles to them as ending angles.
        elevation_range: A tuple for minimum and maximum elevation angles to
            randomly rotate 3D keypoints with. For non-sequential inputs, a
            2-tuple for (minimum angle, maximum angle) is expected. For
            sequence inputs, uses 2-tuple to independently sample starting and
            ending camera angles, or uses 4-tuple for (minimum starting angle,
            maximum starting angle, minimum angle increment, maximum angle
            increment) to first sample starting angles and add random delta
            angles to them as ending angles.
        roll_range: A tuple for minimum and maximum roll angles to randomly
            rotate 3D keypoints with. For non-sequential inputs, a 2-tuple for
            (minimum angle, maximum angle) is expected. For sequence inputs,
            uses 2-tuple to independently sample starting and ending camera
            angles, or uses 4-tuple for (minimum starting angle, maximum
            starting angle, minimum angle increment, maximum angle increment)
            to first sample starting angles and add random delta angles to
            them as ending angles.
        normalized_camera_depth_range: A tuple for minimum and maximum
            normalized camera depth for random camera augmentation. If empty,
            uses constant depth as 1 over the 2D pose normalization scale
            unit.
        sequential_inputs: A boolean flag indicating whether the inputs are
            sequential. If True, the input keypoints are supposed to be in
            shape [..., sequence_length, num_keypoints, keypoint_dim].
        seed: An integer for random seed.

    Returns:
        features: A tensor for input features. Shape = [..., feature_dim].
        side_outputs: A dictionary for side outputs, which includes
            `offset_points_2d` (shape = [..., 1, 2]) and `scale_distances_2d`
            (shape = [..., 1, 1]) if `normalize_keypoints_2d` is True. The
            preprocessed 2D keypoints and keypoint masks are always included.

    Raises:
        ValueError: If `model_input_keypoint_type` is not supported.
        ValueError: If `keypoint_dropout_probs` is not of length 2 or 3.
        ValueError: If `keypoint_profile_2d` is not specified when normalizing
            2D keypoints.
        ValueError: If keypoint profile name is not 'LEGACY_2DCOCO13',
            '2DSTD13', or 'INTERNAL_2DSTD13' when applying structured keypoint
            dropout.
        ValueError: If number of instances is not 1 or 2.
        ValueError: If `keypoint_profile_2d` is not specified when forcing
            keypoint masks on or off.
    """
    keypoints_2d, keypoint_masks_2d = preprocess_keypoints_2d(
        keypoints_2d,
        keypoint_masks_2d,
        keypoints_3d,
        model_input_keypoint_type,
        keypoint_profile_2d=keypoint_profile_2d,
        keypoint_profile_3d=keypoint_profile_3d,
        azimuth_range=azimuth_range,
        elevation_range=elevation_range,
        roll_range=roll_range,
        normalized_camera_depth_range=normalized_camera_depth_range,
        sequential_inputs=sequential_inputs,
        seed=seed)

    side_outputs = {}

    if len(keypoint_dropout_probs) not in [2, 3]:
        raise ValueError('Invalid keypoint dropout probability tuple: `%s`.' %
                         str(keypoint_dropout_probs))

    # Random dropout is only active when both probabilities are positive.
    if keypoint_dropout_probs[0] > 0.0 and keypoint_dropout_probs[1] > 0.0:
        instance_keypoint_masks_2d = apply_stratified_instance_keypoint_dropout(
            keypoint_masks_2d,
            probability_to_apply=keypoint_dropout_probs[0],
            probability_to_drop=keypoint_dropout_probs[1],
            seed=seed)

        if (sequential_inputs and len(keypoint_dropout_probs) == 3
                and keypoint_dropout_probs[2] > 0.0):
            # Mix sequence-level and instance-level dropout masks; the third
            # probability is how often the sequence-level masks are kept.
            sequence_keypoint_masks_2d = apply_stratified_sequence_keypoint_dropout(
                keypoint_masks_2d,
                probability_to_apply=keypoint_dropout_probs[0],
                probability_to_drop=keypoint_dropout_probs[1],
                seed=seed)
            sequence_axis = sequence_keypoint_masks_2d.shape.ndims - 1
            keypoint_masks_2d = data_utils.mix_batch(
                [sequence_keypoint_masks_2d], [instance_keypoint_masks_2d],
                axis=sequence_axis,
                keep_lhs_prob=keypoint_dropout_probs[2],
                seed=seed)[0]
        else:
            keypoint_masks_2d = instance_keypoint_masks_2d

    if structured_keypoint_mask_processor is not None:
        keypoint_masks_2d = structured_keypoint_mask_processor(
            keypoint_masks=keypoint_masks_2d,
            keypoint_profile=keypoint_profile_2d,
            seed=seed)

    if normalize_keypoints_2d:
        if keypoint_profile_2d is None:
            raise ValueError(
                'Failed to normalize 2D keypoints due to unspecified '
                'keypoint profile.')
        keypoints_2d, offset_points, scale_distances = (
            keypoint_profile_2d.normalize(keypoints_2d, keypoint_masks_2d))
        side_outputs.update({
            common.KEY_OFFSET_POINTS_2D: offset_points,
            common.KEY_SCALE_DISTANCES_2D: scale_distances
        })

    # Jittering is applied after normalization.
    if uniform_keypoint_jittering_max_offset_2d > 0.0:
        keypoints_2d = _add_uniform_keypoint_jittering(
            keypoints_2d,
            max_jittering_offset=uniform_keypoint_jittering_max_offset_2d,
            seed=seed)

    if gaussian_keypoint_jittering_offset_stddev_2d > 0.0:
        keypoints_2d = _add_gaussian_keypoint_jittering(
            keypoints_2d,
            jittering_offset_stddev=(
                gaussian_keypoint_jittering_offset_stddev_2d),
            seed=seed)

    if set_on_mask_for_non_anchors:
        # Axis 1 is assumed to be the anchor/non-anchor dimension, with the
        # anchor at index 0; everything after it gets all-one masks.
        non_anchor_indices = list(
            range(1,
                  keypoint_masks_2d.shape.as_list()[1]))
        if non_anchor_indices:
            keypoint_masks_2d = data_utils.update_sub_tensor(
                keypoint_masks_2d,
                indices=non_anchor_indices,
                axis=1,
                update_func=tf.ones_like)

    if mix_mask_sub_batches:
        # Mix all-one masks with the processed masks along axis 1.
        keypoint_masks_2d = data_utils.mix_batch(
            [tf.ones_like(keypoint_masks_2d)], [keypoint_masks_2d], axis=1)[0]

    if forced_mask_on_part_names:
        # Part names can only be resolved to keypoint indices via the profile;
        # fail with the documented error instead of an AttributeError on None.
        if keypoint_profile_2d is None:
            raise ValueError(
                'Failed to force keypoint masks on due to unspecified '
                'keypoint profile.')
        keypoint_masks_2d = _override_keypoint_masks(
            keypoint_masks_2d,
            keypoint_profile=keypoint_profile_2d,
            part_names=forced_mask_on_part_names,
            overriding_func=tf.ones_like)

    # Forced-off is applied after forced-on, so it wins for parts named in
    # both lists.
    if forced_mask_off_part_names:
        if keypoint_profile_2d is None:
            raise ValueError(
                'Failed to force keypoint masks off due to unspecified '
                'keypoint profile.')
        keypoint_masks_2d = _override_keypoint_masks(
            keypoint_masks_2d,
            keypoint_profile=keypoint_profile_2d,
            part_names=forced_mask_off_part_names,
            overriding_func=tf.zeros_like)

    if model_input_keypoint_mask_type in [
            common.MODEL_INPUT_KEYPOINT_MASK_TYPE_MASK_KEYPOINTS,
            common.MODEL_INPUT_KEYPOINT_MASK_TYPE_MASK_KEYPOINTS_AND_AS_INPUT
    ]:
        # Mask out invalid keypoints: coordinates whose mask is not exactly
        # 1.0 are replaced with zeros.
        keypoints_2d = tf.where(
            data_utils.tile_last_dims(
                tf.expand_dims(tf.math.equal(keypoint_masks_2d, 1.0), axis=-1),
                last_dim_multiples=[tf.shape(keypoints_2d)[-1]]), keypoints_2d,
            tf.zeros_like(keypoints_2d))

    side_outputs.update({
        common.KEY_PREPROCESSED_KEYPOINTS_2D:
        keypoints_2d,
        common.KEY_PREPROCESSED_KEYPOINT_MASKS_2D:
        keypoint_masks_2d,
    })

    features = keypoints_2d
    if model_input_keypoint_mask_type in [
            common.MODEL_INPUT_KEYPOINT_MASK_TYPE_AS_INPUT,
            common.MODEL_INPUT_KEYPOINT_MASK_TYPE_MASK_KEYPOINTS_AND_AS_INPUT
    ]:
        # Append each keypoint's mask value as an extra feature channel.
        features = tf.concat(
            [keypoints_2d,
             tf.expand_dims(keypoint_masks_2d, axis=-1)],
            axis=-1)

    if rescale_features:
        # Scale up features to compensate for any keypoint masking. The
        # denominator is clamped so an all-zero mask does not divide by zero.
        feature_rescales = keypoint_masks_2d.shape.as_list()[-1] / (
            tf.math.maximum(
                1e-12,
                tf.math.reduce_sum(keypoint_masks_2d, axis=-1, keepdims=True)))
        features *= tf.expand_dims(feature_rescales, axis=-1)

    features = data_utils.flatten_last_dims(features, num_last_dims=2)
    return features, side_outputs