def _ragged_to_block_sparse(args):
    """Convert ragged-rank-2 batched indices into block-sparse ``(i, j)``.

    Args:
        args: pair ``(ragged_indices, offset)``. ``ragged_indices`` must have
            ragged rank 2. ``offset`` is indexed by the outer (batch) row id
            and added to the inner values, so each outer row's indices are
            shifted into their own block.

    Returns:
        ``(i, j)``: ``i`` is the flattened inner-row id of every value, and
        ``j`` is the value plus the offset of the outer row it belongs to.
    """
    nested, offset = args
    assert ragged_rank(nested) == 2
    # Outer (batch) row id of every inner row.
    row_batch = ragged_layers.value_rowids(nested)
    rows = ragged_layers.values(nested)
    # Inner-row id of every value. After flattening the outer dimension these
    # ids are already global, so `i` needs no offset.
    i = ragged_layers.value_rowids(rows)
    # Outer (batch) row id of every value.
    value_batch = tf.gather(row_batch, i)
    j = ragged_layers.values(rows) + tf.gather(offset, value_batch)
    return i, j
def __init__(self, coords: FloatTensor):
    """Set up the batched and in-model views of ``coords``.

    ``coords`` is cached and batched through the pipeline (``pl``). The
    batched ragged tensor is then registered as a model input. For both the
    batched and the model version, the ragged structure and the flat values
    are stored separately.

    Args:
        coords: pre-cache coordinates.
    """
    self._coords = coords

    # Post-batch ragged view and its structure.
    ragged_batched = as_ragged(pl.batch(pl.cache(coords)))
    self._batched_structure = RaggedStructure.from_ragged(ragged_batched)

    # In-model ragged view and its structure.
    ragged_model = pl.model_input(ragged_batched)
    self._model_structure = RaggedStructure.from_ragged(ragged_model)

    # Flat values of each view (the structure is kept separately above).
    self._batched_coords = ragged_layers.values(ragged_batched)
    self._model_coords = ragged_layers.values(ragged_model)
def __init__(  # pylint:disable=super-init-not-called
    self, in_cloud: Cloud, indices: IntTensor
):
    """Build a cloud defined by ragged ``indices`` into ``in_cloud``.

    Args:
        in_cloud: the cloud that ``indices`` refer to.
        indices: pre-cache ragged indices into ``in_cloud``. Its leading
            dimension must be statically unknown.
    """
    assert indices.shape[0] is None
    self._in_cloud = in_cloud
    self._indices = indices

    ragged_batched = pl.batch(pl.cache(indices))
    self._batched_structure = RaggedStructure.from_ragged(ragged_batched)

    # post-batch: shift every index by the batched `in_cloud` row start of the
    # batch element it belongs to, so it indexes the flat batched values.
    flat_values = ragged_layers.values(ragged_batched)
    row_offsets = tf.gather(
        self.in_cloud.batched_structure.row_starts,
        self.batched_structure.value_rowids,
    )
    flat_batched = flat_values + row_offsets

    ragged_model = pl.model_input(
        self._batched_structure.as_ragged(flat_batched)
    )
    self._batched_indices = flat_batched
    self._model_indices = ragged_layers.values(ragged_model)
    self._model_structure = RaggedStructure.from_ragged(ragged_model)
def voxelize(
    self, reduction, features, t_start, t_end, num_frames: int, batch_size=None
):
    """Reduce per-point features into a dense ``(batch, frame, *grid, F)`` tensor.

    Post-batch, each point's batch index, frame index (from
    ``self.frame_indices``) and batched grid coordinate are raveled into a
    single segment id. In the model, ``reduction`` aggregates the flat
    feature values over those segment ids.

    Args:
        reduction: segment reduction, called as
            ``reduction(values, segment_ids, num_segments=...)``.
        features: ragged model features. ``values(features)`` must be rank 2
            and the last dimension must be statically known.
        t_start: forwarded to ``self.frame_indices``.
        t_end: forwarded to ``self.frame_indices``.
        num_frames: number of frames per example.
        batch_size: number of examples. If ``None``, it is read from the
            leading dimension of ``features``, which must then be rank 3.

    Returns:
        Tensor reshaped to ``(-1, num_frames, *self.grid.static_shape, F)``
        with ``F = features.shape[-1]``. The leading dimension equals
        ``batch_size``.
    """
    static_shape = self.grid.static_shape
    assert static_shape is not None
    # int64 so large grids do not overflow a platform-dependent default int.
    static_size = np.prod(static_shape, dtype=np.int64)
    assert features.shape[-1] is not None
    num_frames = np.array(num_frames, dtype=np.int64)

    # post-batch: flat segment id for every point
    batch_index = self.batched_structure.value_rowids
    batched_coords = maybe_cast(self.batched_coords, batch_index.dtype)
    time = self.frame_indices(t_start, t_end, num_frames, dtype=batch_index.dtype)
    dims = tf.stack(
        (
            maybe_cast(self.batched_structure.nrows, tf.int64),
            num_frames,
            static_size,
        ),
        axis=0,
    )
    indices = tf.stack((batch_index, time, batched_coords), axis=0)
    indices = grid_layers.ravel_multi_index(indices, dims, axis=0)
    indices = pl.model_input(indices)

    # model: reduce features over the segment ids
    if batch_size is None:
        features.shape.assert_has_rank(3)
        batch_size = tf.shape(features)[0]
    assert is_ragged(features)
    features = ragged_wrappers.values(features)
    features.shape.assert_has_rank(2)
    features = reduction(
        features,
        indices,
        # Reuse the int64 grid size from above rather than recomputing it
        # with numpy's default integer dtype.
        num_segments=batch_size * (num_frames * static_size),
    )
    features = Lambda(
        tf.reshape,
        arguments=dict(shape=(-1, num_frames, *static_shape, features.shape[-1])),
    )(features)
    return features
def neighborhood(
    in_cloud: Cloud,
    out_cloud: Cloud,
    radius: float,
    indices: tf.RaggedTensor,
    edge_features_fn: TensorMap,
    weight_fn: Optional[TensorMap] = None,
    normalize: bool = True,
    version: str = "v0",
) -> "Neighborhood":
    """
    Get a `Neighborhood` from relevant clouds and parameters.

    Args:
        in_cloud: input cloud.
        out_cloud: output cloud.
        radius: radius of ball search. Relative coordinates are divided by
            this value before being passed to `edge_features_fn` / `weight_fn`.
        indices: pre-cache indices into `in_cloud` corresponding to neighbors
            in `in_cloud` of `out_cloud`.
        edge_features_fn: function producing (N, E?) edge features.
        weight_fn: weighting function to apply to edge features, function of
            the normalized edge length. If given, edge features are scaled by
            this value.
        normalize: if True, edge features are divided by the sum over the
            neighborhood of weights given by `weight_fn`.
        version: string indicating version which influences what to cache /
            when to apply various meta-ops. Supported:
                "v0": cache only minimal values and compute relative coords,
                    edge features and weights in the model.
                "v1": cache minimal values, compute relative coords post-batch,
                    edge features / weights in model.
                "v2": cache minimal values, compute relative coords, edge
                    features and weights post-batch.
                "v3": cache relative coordinates, compute edge features /
                    weights in model.
                "v4": cache relative coordinates, edge features and weights.

    Returns:
        `Neighborhood`.

    Raises:
        ValueError: if `version` is not one of the supported values.
    """
    if version not in ("v0", "v1", "v2", "v3", "v4"):
        # Fail before any cache / batch / model-input ops are created.
        raise ValueError(f"invalid version {version}")

    def get_weights(rel_coords):
        # Weights depend on the (radius-normalized) edge length only.
        if weight_fn is None:
            return None
        return weight_fn(tf.linalg.norm(rel_coords, axis=-1))

    cached_indices = pl.cache(indices)
    batched_indices = pl.batch(cached_indices)
    # post-batch: block-diagonal (out, in) indices into the batched clouds
    i, j = ragged_to_block_sparse(
        batched_indices, in_cloud.batched_structure.row_starts
    )
    model_i = pl.model_input(i)
    model_j = pl.model_input(j)
    sparse_indices = tf_stack((model_i, model_j), axis=-1)

    if version == "v0":
        # compute rel coords / edges features in model
        rel_coords = tf.gather(in_cloud.model_coords / radius, model_j) - tf.gather(
            out_cloud.model_coords / radius, model_i
        )
        edge_features = edge_features_fn(rel_coords)
        weights = get_weights(rel_coords)
    elif version == "v1":
        # compute rel coords in batch, rest in model
        rel_coords = tf.gather(in_cloud.batched_coords / radius, j) - tf.gather(
            out_cloud.batched_coords / radius, i
        )
        rel_coords = pl.model_input(rel_coords)
        # edge features in model
        edge_features = edge_features_fn(rel_coords)
        weights = get_weights(rel_coords)
    elif version == "v2":
        # compute edge features in batch
        rel_coords = tf.gather(in_cloud.batched_coords / radius, j) - tf.gather(
            out_cloud.batched_coords / radius, i
        )
        edge_features = edge_features_fn(rel_coords)
        weights = get_weights(rel_coords)
        edge_features = pl.model_input(edge_features)
        # `weights` is None when `weight_fn` is None: nothing to feed then.
        if weights is not None:
            weights = pl.model_input(weights)
    elif version == "v3":
        # cache relative coords
        # pre-cache ids: rows index `out_cloud`, values index `in_cloud`
        pre_i = ragged_layers.value_rowids(indices)
        pre_j = ragged_layers.values(indices)
        rel_coords = tf.gather(in_cloud.coords / radius, pre_j) - tf.gather(
            out_cloud.coords / radius, pre_i
        )
        rel_coords = pl.model_input(pl.batch(pl.cache(rel_coords)))
        rel_coords = ragged_layers.values(rel_coords)
        edge_features = edge_features_fn(rel_coords)
        weights = get_weights(rel_coords)
    else:  # version == "v4"
        # cache edge features / weights
        # pre-cache ids: rows index `out_cloud`, values index `in_cloud`
        pre_i = ragged_layers.value_rowids(indices)
        pre_j = ragged_layers.values(indices)
        rel_coords = tf.gather(in_cloud.coords / radius, pre_j) - tf.gather(
            out_cloud.coords / radius, pre_i
        )
        edge_features = edge_features_fn(rel_coords)  # (N, E?)
        # Transpose so the ragged (edge) dimension leads for caching / batching.
        edge_features = tf.transpose(edge_features, (1, 0))  # (E?, N)
        edge_features = pl.batch(pl.cache(edge_features))  # (B, E?, N)
        edge_features = pl.model_input(edge_features)
        edge_features = ragged_layers.values(edge_features)  # (BE, N)
        edge_features = tf.transpose(edge_features, (1, 0))  # (N, BE)
        weights = get_weights(rel_coords)  # (E,)
        # `weights` is None when `weight_fn` is None: nothing to cache then.
        if weights is not None:
            weights = pl.model_input(pl.batch(pl.cache(weights)))  # (B, E?)
            weights = ragged_layers.values(weights)  # (BE,)

    assert edge_features.shape[0] is not None
    return Neighborhood(
        in_cloud, out_cloud, sparse_indices, edge_features, weights, normalize=normalize
    )
def _batch_multi_partition(self) -> Tuple[tf.SparseTensor, ...]:
    """Build one block-diagonal batched sparse tensor per partition.

    `self._partitions` splits the (row id, index) pairs given by
    `self._splits` / `self._indices` into `self.num_partitions` groups. For
    each group:
      1. Convert it to ragged (values, row_splits) components.
      2. Cache and batch the components through the pipeline (`pl`).
      3. Post-batch, convert them to sparse (i, j) indices, with `j` shifted
         by the `in_stream` batched row start of the batch element it
         belongs to.

    Each sparse tensor's values are
    `float32(out_stream.batched_times[i] - in_stream.batched_times[j])`
    divided by `self.decay_time`.

    Returns:
        Tuple of `num_partitions` sparse tensors with int64 indices and dense
        shape `(out_stream batched total size, in_stream batched total size)`.
    """
    num_partitions = self.num_partitions
    assert num_partitions > 1
    components = []
    # Row id of every entry of `self._indices`.
    rowids = tf.ragged.row_splits_to_segment_ids(self._splits, out_type=self.dtype)
    partitions = maybe_cast(self._partitions, tf.int32)
    # Split stacked (row id, index) pairs into `num_partitions` groups.
    ijs = tf.dynamic_partition(
        tf_stack((rowids, self._indices), axis=-1, name="multi_partition_stack"),
        partitions,
        num_partitions,
    )
    # sorted transpose: dynamic_partition keeps the relative order of entries,
    # so each partition's row ids `i` stay sorted. NOTE(review): presumably
    # required by from_value_rowids; confirm against ragged_wrappers.
    for ij in ijs:
        ij.shape.assert_has_rank(2)
        assert ij.shape[1] == 2
        # num required in tf-nightly (2.5)
        i, j = tf.unstack(ij, num=2, axis=-1)
        indices = ragged_wrappers.from_value_rowids(
            j, i, nrows=maybe_cast(self.out_stream.size, i.dtype))
        components.append(ragged_components(indices))

    # Cache each partition's (values, row_splits) and batch them as int64.
    all_ragged_indices = [
        pl.batch(
            ragged_wrappers.from_row_splits(
                tf.cast(pl.cache(v), tf.int64), tf.cast(pl.cache(rs), tf.int64)))
        for v, rs in components
    ]

    in_stream = self.in_stream
    out_stream = self.out_stream

    # ragged to sparse
    counts = []
    all_b = []
    all_i = []
    all_j = []
    for ragged_indices in all_ragged_indices:
        # Outer (batch) row id of every flattened row...
        b = tf.ragged.row_splits_to_segment_ids(
            ragged_wrappers.row_splits(ragged_indices), out_type=self.dtype)
        ragged_indices = ragged_wrappers.values(ragged_indices)
        # ...repeated once per entry of that row.
        b = tf.repeat(b, ragged_wrappers.row_lengths(ragged_indices), axis=0)
        # sparse indices must eventually be int64
        i = tf.ragged.row_splits_to_segment_ids(
            ragged_wrappers.row_splits(ragged_indices), out_type=tf.int64)
        j = tf.cast(ragged_wrappers.values(ragged_indices), tf.int64)
        # Number of entries in this partition (its last row split); used to
        # split the concatenated results back apart below.
        counts.append(ragged_wrappers.row_splits(ragged_indices)[-1])
        all_b.append(b)
        all_i.append(i)
        all_j.append(j)

    # concat for efficient dt / block diagonalizing.
    cat_b = tf.concat(all_b, axis=0)
    cat_i = tf.concat(all_i, axis=0)
    cat_j = tf.concat(all_j, axis=0)

    # block diagonalize
    # skip i offset since it was automatically done in ragged batching
    cat_j = cat_j + tf.gather(in_stream.batched_structure.row_starts, cat_b)
    # Time difference between each out_stream element and its in_stream
    # neighbor, as float32, divided by the decay time.
    cat_dt = (tf.cast(
        tf.gather(out_stream.batched_times, cat_i) -
        tf.gather(in_stream.batched_times, cat_j),
        tf.float32,
    ) / self.decay_time)
    cat_ij = tf_stack((cat_i, cat_j), axis=-1)
    dense_shape = tf_stack(
        (
            maybe_cast(out_stream.batched_structure.total_size, tf.int64),
            maybe_cast(in_stream.batched_structure.total_size, tf.int64),
        ),
        axis=0,
    )
    # tf.SparseTensor indices and dense_shape must be int64
    if dense_shape.dtype != tf.int64:
        dense_shape = tf.cast(dense_shape, tf.int64)
    # Undo the concatenation: one (indices, values) pair per partition.
    dts = tf.split(cat_dt, counts)
    ijs = tf.split(cat_ij, counts)
    return tuple(
        sparse_wrappers.SparseTensor(maybe_cast(
            ij, tf.int64), dt, dense_shape) for ij, dt in zip(ijs, dts))
def ragged_components(rt):
    """Split a ragged-rank-1 tensor into its ``(values, row_splits)`` pair.

    Args:
        rt: ragged tensor with exactly one ragged dimension.

    Returns:
        Tuple ``(values, row_splits)`` of ``rt``.
    """
    assert ragged_rank(rt) == 1
    vals = ragged_wrappers.values(rt)
    splits = ragged_wrappers.row_splits(rt)
    return vals, splits
import tensorflow as tf
from wtftf.ragged import layers as rl

# Demo: per-row sum of a ragged Keras input, using the wtftf ragged layer
# wrappers to get the flat values and the row id of each value.
rt = tf.keras.Input(shape=(None,), ragged=True)
values, row_ids = rl.values(rt), rl.value_rowids(rt)
# NOTE(review): tf.math.segment_sum outputs max(row_ids) + 1 rows, so trailing
# empty rows of the input would be missing from the result.
seg_sum = tf.math.segment_sum(values, row_ids)
model = tf.keras.Model(inputs=rt, outputs=seg_sum)

print(model(tf.ragged.range([2, 3])))