示例#1
0
 def test_cat(self):
     """Concatenating two 1-D MaskedTensors along axis 0 yields a single 1-D MaskedTensor."""
     tensor1 = MaskedTensor(tf.constant([1, 2]))
     tensor2 = MaskedTensor(tf.constant([3, 4]))
     concatenated = MaskedTensorflow.concat([tensor1, tensor2], axis=0)
     # Two shape-[2] inputs concatenated on axis 0 give shape [4]. A [[...]] expected value
     # (shape [1, 4]) would still compare equal through broadcasting and hide a wrong rank.
     res = MaskedTensor(tf.constant([1, 2, 3, 4]))
     self.assertEqual(concatenated.tensor.shape, res.tensor.shape)
     self.assertTrue(tf.reduce_all(concatenated == res),
                     msg="Cat is not equal to expected")
示例#2
0
def process_datum(datum, augment=False):
    """Prepare a single datum as an input-output pair for training / evaluation.

    The only supported data augmentation is frame dropout. Dropping frames lowers
    the effective FPS, which changes the optical flow, so the FPS is rescaled by the
    fraction of frames kept before the flow is computed.

    Args:
        datum (Dict[str, tf.Tensor]): a dictionary of tensors loaded from the
          tfrecord. Keys read here: "pose_data_tensor", "pose_data_mask", "fps",
          "pose_confidence", "tgt" and "frames".
        augment (bool): whether to apply frame-dropout augmentation to the datum.

    Returns:
        Dict[str, tf.Tensor]: dictionary with "src" (optical flow of the pose data)
        and "tgt" (the targets with the first entry removed).
    """
    masked_tensor = MaskedTensor(tensor=datum["pose_data_tensor"], mask=datum["pose_data_mask"])
    pose_body = TensorflowPoseBody(fps=datum["fps"], data=masked_tensor, confidence=datum["pose_confidence"])
    pose = Pose(header=get_pose_header(), body=pose_body)
    tgt = datum["tgt"]

    fps = pose.body.fps
    frames = datum["frames"]

    if augment:
        pose, selected_indexes = pose.frame_dropout(FLAGS.frame_dropout_std)
        tgt = tf.gather(tgt, selected_indexes)

        # NOTE(review): tf.size counts all elements, so this equals the kept frame
        # count only if tgt holds one entry per frame (1-D) — confirm.
        new_frames = tf.cast(tf.size(tgt), dtype=fps.dtype)
        # Cast the original count to the same dtype as fps/new_frames; TensorFlow
        # division does not mix integer and floating dtypes.
        original_frames = tf.cast(frames, dtype=fps.dtype)

        # Fewer frames kept => proportionally lower FPS, clamped from below by minimum_fps.
        fps = tf.math.maximum(minimum_fps, (new_frames / original_frames) * fps)

    flow = optical_flow(pose.body.data, fps)
    # First frame tag is not used (presumably because the flow has one entry fewer
    # than the number of frames — TODO confirm against optical_flow).
    tgt = tgt[1:]

    return {"src": flow, "tgt": tgt}
示例#3
0
        def func(*args, **kwargs):
            # Plain call (no args, or first arg is not masked): forward straight to tensorflow.
            if not args or not isinstance(args[0], MaskedTensor):
                return getattr(tensorflow, attr)(*args, **kwargs)

            # Only the first positional argument is unwrapped; its mask is kept aside.
            first, *rest = args
            saved_mask = first.mask
            raw_tensor = first.tensor

            outcome = getattr(tensorflow, attr)(raw_tensor, *rest, **kwargs)

            # Ops listed in doesnt_change_mask get the original mask re-attached;
            # any other op returns the raw tensorflow result unwrapped.
            if attr not in TensorflowFallback.doesnt_change_mask:
                return outcome
            return MaskedTensor(outcome, saved_mask)
示例#4
0
 def test_not_implemented_method(self):
     """MaskedTensorflow.sum on a MaskedTensor matches the plain tensorflow sum."""
     # NOTE(review): the test name suggests `sum` is served by the generic tensorflow fallback.
     masked_input = MaskedTensor(tensor=tf.constant([1, 2, 3]))
     total = MaskedTensorflow.sum(masked_input)
     expected_total = tf.constant(6)
     self.assertEqual(total, expected_total)
示例#5
0
 def zeros(size, dtype=None) -> MaskedTensor:
     """Return a MaskedTensor of zeros with shape `size`, paired with an all-False boolean mask."""
     return MaskedTensor(
         tensor=tensorflow.zeros(size, dtype=dtype),
         mask=tensorflow.zeros(size, dtype=tensorflow.bool),
     )
示例#6
0
 def stack(tensors: List[Union[MaskedTensor, tensorflow.Tensor]], axis: int) -> MaskedTensor:
     """Stack tensors along a new `axis`, stacking their masks the same way.

     Like `concat`, plain tensorflow tensors are accepted and first wrapped in a
     MaskedTensor (with MaskedTensor's default mask), so masked and unmasked inputs
     can be mixed.
     """
     masked = [t if isinstance(t, MaskedTensor) else MaskedTensor(tensor=t) for t in tensors]
     tensor = tensorflow.stack([t.tensor for t in masked], axis=axis)
     mask = tensorflow.stack([t.mask for t in masked], axis=axis)
     return MaskedTensor(tensor=tensor, mask=mask)
示例#7
0
 def concat(tensors: List[Union[MaskedTensor, tensorflow.Tensor]], axis: int) -> MaskedTensor:
     """Concatenate tensors along `axis`, concatenating their masks the same way.

     Plain tensorflow tensors are first wrapped in a MaskedTensor (with MaskedTensor's
     default mask), so masked and unmasked inputs can be mixed.
     """
     def as_masked(item):
         # Leave MaskedTensors as they are; wrap anything else.
         return item if isinstance(item, MaskedTensor) else MaskedTensor(tensor=item)

     wrapped = list(map(as_masked, tensors))
     return MaskedTensor(
         tensor=tensorflow.concat([w.tensor for w in wrapped], axis=axis),
         mask=tensorflow.concat([w.mask for w in wrapped], axis=axis),
     )