def test_cat(self):
    """Concatenating two 1-D masked tensors along axis 0 yields a single 1-D tensor."""
    tensor1 = MaskedTensor(tf.constant([1, 2]))
    tensor2 = MaskedTensor(tf.constant([3, 4]))
    concatenated = MaskedTensorflow.concat([tensor1, tensor2], axis=0)

    # The expected value has the same rank as the result, so a rank mismatch cannot
    # slip through if the element-wise `==` happens to broadcast.
    expected = MaskedTensor(tf.constant([1, 2, 3, 4]))
    self.assertEqual(concatenated.tensor.shape.as_list(), [4])
    self.assertTrue(tf.reduce_all(concatenated == expected), msg="Cat is not equal to expected")
def process_datum(datum, augment=False):
    """Turn one decoded tfrecord example into a ``{"src", "tgt"}`` pair for training / eval.

    The pose tensor, its mask and its confidence are wrapped in a ``Pose``, and the
    optical flow of the pose data becomes the source. The only augmentation supported
    is frame dropout. Dropping frames lowers the effective FPS (floored at
    ``minimum_fps``), which in turn changes the optical flow.

    Args:
        datum (Dict[str, tf.Tensor]): tensors loaded from the tfrecord. Keys read:
            "pose_data_tensor", "pose_data_mask", "fps", "pose_confidence", "tgt", "frames".
        augment (bool): whether to apply frame-dropout augmentation.

    Returns:
        Dict[str, tf.Tensor]: ``{"src": flow, "tgt": tags}``; the first frame's tag is dropped.
    """
    pose_data = MaskedTensor(tensor=datum["pose_data_tensor"], mask=datum["pose_data_mask"])
    body = TensorflowPoseBody(fps=datum["fps"], data=pose_data, confidence=datum["pose_confidence"])
    pose = Pose(header=get_pose_header(), body=body)

    tags = datum["tgt"]
    effective_fps = pose.body.fps
    original_frame_count = datum["frames"]

    if augment:
        pose, kept_indexes = pose.frame_dropout(FLAGS.frame_dropout_std)
        tags = tf.gather(tags, kept_indexes)
        kept_frame_count = tf.cast(tf.size(tags), dtype=effective_fps.dtype)
        # Scale FPS by the fraction of frames that survived, but never below minimum_fps.
        effective_fps = tf.math.maximum(
            minimum_fps, (kept_frame_count / original_frame_count) * effective_fps
        )

    flow = optical_flow(pose.body.data, effective_fps)
    # The first frame's tag is not used.
    return {"src": flow, "tgt": tags[1:]}
def func(*args, **kwargs):
    """Forward a call to ``tensorflow.<attr>``, unwrapping a leading MaskedTensor.

    If the first positional argument is a MaskedTensor, its raw ``.tensor`` is passed
    to the TensorFlow function instead. The result is re-wrapped with that argument's
    mask only when ``attr`` is listed in ``TensorflowFallback.doesnt_change_mask``;
    otherwise the raw TensorFlow result is returned. Calls whose first argument is
    not a MaskedTensor are forwarded unchanged.
    """
    if not args or not isinstance(args[0], MaskedTensor):
        # The call is on an unmasked tensor (or has no positional arguments).
        return getattr(tensorflow, attr)(*args, **kwargs)

    # NOTE(review): only args[0] is unwrapped; MaskedTensors in later positions or
    # in kwargs are passed to TensorFlow as-is.
    head, *tail = args
    mask = head.mask
    result = getattr(tensorflow, attr)(head.tensor, *tail, **kwargs)
    if attr not in TensorflowFallback.doesnt_change_mask:
        return result
    return MaskedTensor(result, mask)
def test_not_implemented_method(self):
    """MaskedTensorflow.sum over [1, 2, 3] should equal the tensor constant 6."""
    values = MaskedTensor(tensor=tf.constant([1, 2, 3]))
    total = MaskedTensorflow.sum(values)
    self.assertEqual(total, tf.constant(6))
def zeros(size, dtype=None) -> MaskedTensor:
    """Build a zero-filled MaskedTensor of shape ``size``.

    The values tensor uses ``dtype``, or TensorFlow's default when ``dtype`` is None.
    The mask has the same shape, is boolean and is all False.
    """
    # NOTE(review): whether False means "masked out" depends on MaskedTensor's convention.
    return MaskedTensor(
        tensor=tensorflow.zeros(size, dtype=dtype),
        mask=tensorflow.zeros(size, dtype=tensorflow.bool),
    )
def stack(tensors: List[Union[MaskedTensor, tensorflow.Tensor]], axis: int) -> MaskedTensor:
    """Stack tensors along a new ``axis``, stacking their masks the same way.

    Like ``concat``, plain ``tensorflow.Tensor`` items are accepted and wrapped with
    ``MaskedTensor(tensor=t)`` (its default mask) before stacking. Existing
    MaskedTensor inputs are used as-is.

    Args:
        tensors: the tensors to stack; all must share a shape, as tf.stack requires.
        axis: position of the new axis in the result.

    Returns:
        A MaskedTensor whose values and mask are the stacked inputs.
    """
    masked = [t if isinstance(t, MaskedTensor) else MaskedTensor(tensor=t) for t in tensors]
    tensor = tensorflow.stack([t.tensor for t in masked], axis=axis)
    mask = tensorflow.stack([t.mask for t in masked], axis=axis)
    return MaskedTensor(tensor=tensor, mask=mask)
def concat(tensors: List[Union[MaskedTensor, tensorflow.Tensor]], axis: int) -> MaskedTensor:
    """Concatenate tensors along an existing ``axis``, concatenating their masks too.

    Plain ``tensorflow.Tensor`` items are first wrapped with ``MaskedTensor(tensor=t)``
    (its default mask), so masked and unmasked inputs can be mixed.
    """
    wrapped = []
    for item in tensors:
        if not isinstance(item, MaskedTensor):
            item = MaskedTensor(tensor=item)
        wrapped.append(item)

    return MaskedTensor(
        tensor=tensorflow.concat([m.tensor for m in wrapped], axis=axis),
        mask=tensorflow.concat([m.mask for m in wrapped], axis=axis),
    )