Example 1
def _ragged_to_block_sparse(args):
    """Convert batched ragged neighbor indices to flat block-sparse (i, j) pairs."""
    ragged_indices, offset = args
    assert ragged_rank(ragged_indices) == 2
    # batch index of each output point
    b = ragged_layers.value_rowids(ragged_indices)
    ragged_indices = ragged_layers.values(ragged_indices)
    # flat output-point index of each edge
    i = ragged_layers.value_rowids(ragged_indices)
    b = tf.gather(b, i)
    # neighbor index of each edge, offset into the block-diagonal batched input
    j = ragged_layers.values(ragged_indices)
    j = j + tf.gather(offset, b)
    return i, j
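
A standalone eager sketch of the same conversion using plain tf.RaggedTensor methods on toy data; the `offset` values stand in for the per-example row starts of the flattened input cloud (names and numbers here are illustrative, not part of the library):

import tensorflow as tf

# 2 examples: example 0 has two output points, example 1 has one;
# the innermost lists are each output point's neighbor indices.
ragged_indices = tf.ragged.constant(
    [[[0, 2], [1]], [[0, 1, 2]]], dtype=tf.int64, ragged_rank=2
)
offset = tf.constant([0, 3], dtype=tf.int64)  # row starts of the batched input cloud

b = ragged_indices.value_rowids()  # batch id per output point: [0, 0, 1]
inner = ragged_indices.values      # ragged (output point, neighbors)
i = inner.value_rowids()           # flat output-point id per edge
j = inner.values + tf.gather(offset, tf.gather(b, i))  # block-diagonal input index
print(i.numpy(), j.numpy())  # [0 0 1 2 2 2] [0 2 1 3 4 5]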
Example 2
    def __init__(self, coords: FloatTensor):
        self._coords = coords

        # cache per-example coords, then batch them into a ragged tensor
        batched_coords = as_ragged(pl.batch(pl.cache(coords)))
        self._batched_structure = RaggedStructure.from_ragged(batched_coords)

        # expose the batched coords as a model input
        model_coords = pl.model_input(batched_coords)
        self._model_structure = RaggedStructure.from_ragged(model_coords)

        # flat coordinate values at the post-batch and in-model stages
        self._batched_coords = ragged_layers.values(batched_coords)
        self._model_coords = ragged_layers.values(model_coords)
Example 3
    def __init__(  # pylint:disable=super-init-not-called
        self, in_cloud: Cloud, indices: IntTensor
    ):
        assert indices.shape[0] is None
        self._in_cloud = in_cloud
        self._indices = indices
        batched_indices = pl.batch(pl.cache(indices))
        self._batched_structure = RaggedStructure.from_ragged(batched_indices)
        # post-batch: offset each index by the row start of its batch element so
        # it addresses the flat (block-diagonal) batched input cloud
        flat_batched_indices = ragged_layers.values(batched_indices) + tf.gather(
            self.in_cloud.batched_structure.row_starts,
            self.batched_structure.value_rowids,
        )
        model_indices = pl.model_input(
            self._batched_structure.as_ragged(flat_batched_indices)
        )
        self._batched_indices = flat_batched_indices
        self._model_indices = ragged_layers.values(model_indices)
        self._model_structure = RaggedStructure.from_ragged(model_indices)
Example 4
    def voxelize(self,
                 reduction,
                 features,
                 t_start,
                 t_end,
                 num_frames: int,
                 batch_size=None):
        static_shape = self.grid.static_shape
        assert static_shape is not None
        static_size = np.prod(static_shape, dtype=np.int64)
        assert features.shape[-1] is not None
        num_frames = np.array(num_frames, dtype=np.int64)

        batch_index = self.batched_structure.value_rowids
        batched_coords = maybe_cast(self.batched_coords, batch_index.dtype)
        time = self.frame_indices(t_start,
                                  t_end,
                                  num_frames,
                                  dtype=batch_index.dtype)
        dims = tf.stack(
            (
                maybe_cast(self.batched_structure.nrows, tf.int64),
                num_frames,
                static_size,
            ),
            axis=0,
        )
        # flatten (batch, frame, voxel) triples into a single linear index
        indices = tf.stack((batch_index, time, batched_coords), axis=0)
        indices = grid_layers.ravel_multi_index(indices, dims, axis=0)

        indices = pl.model_input(indices)
        if batch_size is None:
            # ragged features: infer the batch size and work on the flat values
            features.shape.assert_has_rank(3)
            batch_size = tf.shape(features)[0]
            assert is_ragged(features)
            features = ragged_wrappers.values(features)
        features.shape.assert_has_rank(2)
        # segment-reduce features into one slot per (batch, frame, voxel)
        features = reduction(
            features,
            indices,
            num_segments=batch_size * (num_frames * np.prod(static_shape)),
        )
        features = Lambda(
            tf.reshape,
            arguments=dict(shape=(-1, num_frames, *static_shape,
                                  features.shape[-1])),
        )(features)

        return features
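
The core of `voxelize` is flattening (batch, frame, voxel) triples into a single linear index. `grid_layers.ravel_multi_index` is library-specific; the sketch below reproduces the row-major behaviour it is assumed to implement, using only plain TensorFlow ops and made-up sizes:

import tensorflow as tf

def ravel_multi_index(indices, dims, axis=0):
    """Row-major flattening of multi-indices (a sketch of the assumed helper)."""
    dims = tf.cast(tf.convert_to_tensor(dims), tf.int64)
    # strides[k] = prod(dims[k + 1:]), so (b, t, v) -> (b * T + t) * V + v
    strides = tf.math.cumprod(dims, exclusive=True, reverse=True)
    shape = [1] * indices.shape.rank
    shape[axis] = -1
    return tf.reduce_sum(
        tf.cast(indices, tf.int64) * tf.reshape(strides, shape), axis=axis
    )

# (batch, frame, voxel) triples with dims (B, T, V) = (2, 4, 8), stacked along axis 0
multi = tf.constant([[0, 1], [3, 0], [5, 7]])
print(ravel_multi_index(multi, dims=(2, 4, 8)))  # [29 39]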
Example 5
def neighborhood(
    in_cloud: Cloud,
    out_cloud: Cloud,
    radius: float,
    indices: tf.RaggedTensor,
    edge_features_fn: TensorMap,
    weight_fn: Optional[TensorMap] = None,
    normalize: bool = True,
    version: str = "v0",
) -> "Neighborhood":
    """
    Get a `Neighborhood` from relevant clouds and parameters.

    Args:
        in_cloud: input cloud.
        out_cloud: output cloud.
        radius: radius of the ball search.
        indices: pre-cache indices into `in_cloud` giving, for each point in
            `out_cloud`, its neighbors in `in_cloud`.
        edge_features_fn: function producing (N, E?) edge features.
        weight_fn: weighting function applied to the normalized edge length; if
            given, edge features are scaled by the resulting value.
        normalize: if True, edge features are divided by the sum over the
            neighborhood of the weights given by `weight_fn`.
        version: string indicating the version, which influences what is cached and
            when the various meta-ops are applied. Supported:
            "v0": cache only minimal values and compute relative coords, edge features
                and weights in the model.
            "v1": cache minimal values, compute relative coords post-batch, edge
                features / weights in model.
            "v2": cache minimal values, compute relative coords, edge features and
                weights post-batch.
            "v3": cache relative coordinates, compute edge features / weights in model.
            "v4": cache relative coordinates, edge features and weights.

    Returns:
        `Neighborhood`.
    """

    def get_weights(rel_coords):
        if weight_fn is None:
            return None
        return weight_fn(tf.linalg.norm(rel_coords, axis=-1))

    cached_indices = pl.cache(indices)
    batched_indices = pl.batch(cached_indices)
    # post-batch
    i, j = ragged_to_block_sparse(
        batched_indices, in_cloud.batched_structure.row_starts
    )
    model_i = pl.model_input(i)
    model_j = pl.model_input(j)
    sparse_indices = tf_stack((model_i, model_j), axis=-1)
    if version == "v0":
        # compute rel coords / edge features in model
        rel_coords = tf.gather(in_cloud.model_coords / radius, model_j) - tf.gather(
            out_cloud.model_coords / radius, model_i
        )
        edge_features = edge_features_fn(rel_coords)
        weights = get_weights(rel_coords)
    elif version == "v1":
        # compute rel coords in batch, rest in model
        rel_coords = tf.gather(in_cloud.batched_coords / radius, j) - tf.gather(
            out_cloud.batched_coords / radius, i
        )
        rel_coords = pl.model_input(rel_coords)
        # edge features in model
        edge_features = edge_features_fn(rel_coords)
        weights = get_weights(rel_coords)
    elif version == "v2":
        # compute edge features in batch
        rel_coords = tf.gather(in_cloud.batched_coords / radius, j) - tf.gather(
            out_cloud.batched_coords / radius, i
        )
        edge_features = edge_features_fn(rel_coords)
        weights = get_weights(rel_coords)
        edge_features = pl.model_input(edge_features)
        if weights is not None:
            weights = pl.model_input(weights)
    elif version == "v3":
        # cache relative coords
        i = ragged_layers.value_rowids(indices)
        j = ragged_layers.values(indices)
        rel_coords = tf.gather(in_cloud.coords / radius, j) - tf.gather(
            out_cloud.coords / radius, i
        )
        rel_coords = pl.model_input(pl.batch(pl.cache(rel_coords)))
        rel_coords = ragged_layers.values(rel_coords)
        edge_features = edge_features_fn(rel_coords)
        weights = get_weights(rel_coords)
    elif version == "v4":
        # cache edge features / weights
        i = ragged_layers.value_rowids(indices)
        j = ragged_layers.values(indices)
        rel_coords = tf.gather(in_cloud.coords / radius, j) - tf.gather(
            out_cloud.coords / radius, i
        )
        edge_features = edge_features_fn(rel_coords)  # (N, E?)
        edge_features = tf.transpose(edge_features, (1, 0))  # (E?, N)
        edge_features = pl.batch(pl.cache(edge_features))  # (B, E?, N)
        edge_features = pl.model_input(edge_features)
        edge_features = ragged_layers.values(edge_features)  # (BE, N)
        edge_features = tf.transpose(edge_features, (1, 0))  # (N, BE)

        weights = get_weights(rel_coords)  # (E?,)
        if weights is not None:
            weights = pl.model_input(pl.batch(pl.cache(weights)))  # (B, E?)
            weights = ragged_layers.values(weights)  # (BE,)
    else:
        raise ValueError(f"invalid version {version}")

    assert edge_features.shape[0] is not None

    return Neighborhood(
        in_cloud, out_cloud, sparse_indices, edge_features, weights, normalize=normalize
    )
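
Every version branch above follows the same rel_coords / weights pattern; the standalone sketch below uses plain tensors and a made-up Gaussian `weight_fn` (all names and sizes are illustrative, not part of the library):

import tensorflow as tf

# toy clouds and neighbor pairs: i indexes out_coords, j indexes in_coords
in_coords = tf.random.uniform((10, 3))
out_coords = tf.random.uniform((4, 3))
i = tf.constant([0, 0, 1, 2, 3, 3])
j = tf.constant([2, 5, 1, 9, 0, 4])
radius = 0.5

# relative coordinates, normalized by the ball radius
rel_coords = tf.gather(in_coords / radius, j) - tf.gather(out_coords / radius, i)

# hypothetical weight_fn of the normalized edge length
weights = tf.exp(-tf.linalg.norm(rel_coords, axis=-1) ** 2)

# normalization: divide by the sum of weights over each output point's neighborhood
denom = tf.math.segment_sum(weights, i)
normalized = weights / tf.gather(denom, i)
print(rel_coords.shape, normalized.shape)  # (6, 3) (6,)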
Example 6
    def _batch_multi_partition(self) -> Tuple[tf.SparseTensor, ...]:
        """Build one batched SparseTensor per partition.

        Values are time differences between out- and in-stream events, normalized
        by `decay_time`; indices are into the block-diagonal batched streams.
        """
        num_partitions = self.num_partitions

        assert num_partitions > 1

        components = []

        rowids = tf.ragged.row_splits_to_segment_ids(self._splits,
                                                     out_type=self.dtype)
        partitions = maybe_cast(self._partitions, tf.int32)

        ijs = tf.dynamic_partition(
            tf_stack((rowids, self._indices),
                     axis=-1,
                     name="multi_partition_stack"),
            partitions,
            num_partitions,
        )

        # sorted transpose
        for ij in ijs:
            ij.shape.assert_has_rank(2)
            assert ij.shape[1] == 2
            # num required in tf-nightly (2.5)
            i, j = tf.unstack(ij, num=2, axis=-1)
            indices = ragged_wrappers.from_value_rowids(
                j, i, nrows=maybe_cast(self.out_stream.size, i.dtype))

            components.append(ragged_components(indices))

        all_ragged_indices = [
            pl.batch(
                ragged_wrappers.from_row_splits(
                    tf.cast(pl.cache(v), tf.int64),
                    tf.cast(pl.cache(rs), tf.int64))) for v, rs in components
        ]

        in_stream = self.in_stream
        out_stream = self.out_stream

        # ragged to sparse
        counts = []
        all_b = []
        all_i = []
        all_j = []
        for ragged_indices in all_ragged_indices:
            b = tf.ragged.row_splits_to_segment_ids(
                ragged_wrappers.row_splits(ragged_indices),
                out_type=self.dtype)
            ragged_indices = ragged_wrappers.values(ragged_indices)
            b = tf.repeat(b,
                          ragged_wrappers.row_lengths(ragged_indices),
                          axis=0)
            # sparse indices must eventually be int64
            i = tf.ragged.row_splits_to_segment_ids(
                ragged_wrappers.row_splits(ragged_indices), out_type=tf.int64)
            j = tf.cast(ragged_wrappers.values(ragged_indices), tf.int64)
            counts.append(ragged_wrappers.row_splits(ragged_indices)[-1])
            # total = tf.split(ragged_indices.row_splits, [-1, 1])[1]
            # counts.append(tf.squeeze(total, axis=0))
            all_b.append(b)
            all_i.append(i)
            all_j.append(j)

        # concat for efficient dt / block diagonalizing.
        cat_b = tf.concat(all_b, axis=0)
        cat_i = tf.concat(all_i, axis=0)
        cat_j = tf.concat(all_j, axis=0)

        # block diagonalize
        # skip i offset since it was automatically done in ragged batching
        cat_j = cat_j + tf.gather(in_stream.batched_structure.row_starts,
                                  cat_b)

        cat_dt = (tf.cast(
            tf.gather(out_stream.batched_times, cat_i) -
            tf.gather(in_stream.batched_times, cat_j),
            tf.float32,
        ) / self.decay_time)
        cat_ij = tf_stack((cat_i, cat_j), axis=-1)

        dense_shape = tf_stack(
            (
                maybe_cast(out_stream.batched_structure.total_size, tf.int64),
                maybe_cast(in_stream.batched_structure.total_size, tf.int64),
            ),
            axis=0,
        )
        # tf.SparseTensor indices and dense_shape must be int64
        if dense_shape.dtype != tf.int64:
            dense_shape = tf.cast(dense_shape, tf.int64)

        dts = tf.split(cat_dt, counts)
        ijs = tf.split(cat_ij, counts)

        return tuple(
            sparse_wrappers.SparseTensor(maybe_cast(
                ij, tf.int64), dt, dense_shape) for ij, dt in zip(ijs, dts))
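
The per-partition split above is tf.dynamic_partition applied to stacked (row id, index) pairs; a minimal standalone sketch with made-up data:

import tensorflow as tf

# made-up edges: output row ids, input indices and the partition of each edge
rowids = tf.constant([0, 0, 1, 2, 2, 2], dtype=tf.int64)
indices = tf.constant([3, 7, 1, 0, 4, 6], dtype=tf.int64)
partitions = tf.constant([0, 1, 1, 0, 2, 1], dtype=tf.int32)

ij = tf.stack((rowids, indices), axis=-1)      # (E, 2) pairs
ijs = tf.dynamic_partition(ij, partitions, 3)  # one (E_p, 2) tensor per partition

for p, ij_p in enumerate(ijs):
    i, j = tf.unstack(ij_p, num=2, axis=-1)
    print(p, i.numpy(), j.numpy())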
Example 7
def ragged_components(rt):
    """Return the flat values and row splits of a ragged-rank-1 tensor."""
    assert ragged_rank(rt) == 1
    return ragged_wrappers.values(rt), ragged_wrappers.row_splits(rt)
Example 8
import tensorflow as tf

from wtftf.ragged import layers as rl

# ragged Keras input: each example is a variable-length 1D sequence
rt = tf.keras.Input((None,), ragged=True)
values = rl.values(rt)
row_ids = rl.value_rowids(rt)
seg_sum = tf.math.segment_sum(values, row_ids)

model = tf.keras.Model(rt, seg_sum)
print(model(tf.ragged.range([2, 3])))  # per-row sums of [[0, 1], [0, 1, 2]]
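
For comparison, the same computation in eager mode without the Keras wrappers (assuming the wtftf layers simply defer to the corresponding tf.RaggedTensor ops):

import tensorflow as tf

rt = tf.ragged.range([2, 3])  # [[0, 1], [0, 1, 2]]
print(tf.math.segment_sum(rt.values, rt.value_rowids()))  # [1 3]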