def _batch_single_partition(self, indices, splits) -> tf.SparseTensor:
    in_stream = self._in_stream
    out_stream = self._out_stream
    splits = maybe_cast(splits, self.dtype)
    indices = maybe_cast(indices, self.dtype)
    indices = pl.cache(indices)
    splits = pl.cache(splits)
    ragged_indices = ragged_wrappers.from_row_splits(
        maybe_cast(indices, tf.int64), maybe_cast(splits, tf.int64))
    ragged_indices = pl.batch(ragged_indices)
    b, i, j = sparse_layers.ragged_to_sparse_indices(
        ragged_indices, in_stream.batched_structure.row_starts)
    del b
    # time differences between output and input events, normalized by
    # `decay_time`
    dt = (tf.cast(
        tf.gather(out_stream.batched_times, i) -
        tf.gather(in_stream.batched_times, j),
        tf.float32,
    ) / self.decay_time)
    dense_shape = tf_stack(
        (
            out_stream.batched_structure.total_size,
            in_stream.batched_structure.total_size,
        ),
        axis=0,
    )
    ij = tf_stack((i, j), axis=-1)
    # tf.SparseTensor indices and dense_shape must be int64
    assert ij.dtype == tf.int64
    assert dense_shape.dtype == tf.int64
    dt = sparse_wrappers.SparseTensor(maybe_cast(ij, tf.int64), dt, dense_shape)
    return dt
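# --- Illustrative sketch (not part of the original module) ---
# `_batch_single_partition` returns a SparseTensor whose values are time
# differences normalized by `decay_time`. A typical use is to turn those
# values into exponential-decay weights and apply them with a sparse-dense
# matmul. The snippet below shows the idea with a plain `tf.SparseTensor`;
# the tensors here are hypothetical.
import tensorflow as tf

dt = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 1]],  # (out_event, in_event) pairs
    values=[0.5, 1.0, 0.25],           # normalized time differences
    dense_shape=[2, 3],
)
weights = tf.sparse.map_values(lambda x: tf.exp(-x), dt)  # exp(-dt) decay
features = tf.random.normal((3, 4))                       # per-in-event features
out = tf.sparse.sparse_dense_matmul(weights, features)    # (2, 4)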
def ip_ds_pool_build(
    coords,
    labels,
    r0=0.1125,
    k0=32,
    num_classes=10,
    weight_fn=hat_weight,
    normalize=True,
):
    # in-place down-sample pooling
    in_cloud = comp.Cloud(coords)
    radius = r0
    out_cloud, ip_neigh, ds_neigh = in_cloud.sample_query(
        in_place_radius=radius,
        in_place_k=k0,
        down_sample_radius=radius * np.sqrt(2),
        down_sample_k=k0 * 2,
        edge_features_fn=polynomial_edge_features,
        weight_fn=weight_fn,
        normalize=normalize,
    )
    features = None
    features = ip_neigh.convolve(features, filters=5, activation="relu")
    features = ip_neigh.convolve(features, filters=7, activation="relu")
    features = ds_neigh.convolve(features, filters=11, activation="relu")
    features = tf.math.unsorted_segment_max(
        features,
        out_cloud.model_structure.value_rowids,
        out_cloud.model_structure.nrows,
    )
    logits = tf.keras.layers.Dense(num_classes)(features)
    return logits, pl.batch(pl.cache(labels))
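# --- Illustrative sketch (not part of the original module) ---
# `hat_weight` is the default `weight_fn` above. Per the `neighborhood`
# docstring later in this file, a weight_fn maps normalized edge lengths to
# scalar weights. A minimal triangular ("hat") version consistent with that
# contract could look like the following; the actual definition may differ.
import tensorflow as tf

def hat_weight_sketch(norm_dists: tf.Tensor) -> tf.Tensor:
    # 1 at distance 0, decaying linearly to 0 at the ball radius.
    return tf.nn.relu(1.0 - norm_dists)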
def ip_pool_build(
    coords,
    labels,
    r0=0.1125,
    k0=32,
    num_classes=10,
    weight_fn=hat_weight,
    normalize=True,
):
    # in-place pooling
    del num_classes
    features = None
    labels = pl.batch(pl.cache(labels))
    cloud = comp.Cloud(coords)
    radius = r0
    ip_neigh = cloud.query(radius,
                           polynomial_edge_features,
                           weight_fn,
                           k0,
                           normalize=normalize)
    features = ip_neigh.convolve(features, filters=3, activation="relu")
    features = ip_neigh.convolve(features, filters=2, activation="relu")
    features = tf.math.unsorted_segment_max(features,
                                            cloud.model_structure.value_rowids,
                                            cloud.model_structure.nrows)
    return features, labels
def __init__(self, coords: FloatTensor):
    self._coords = coords
    batched_coords = as_ragged(pl.batch(pl.cache(coords)))
    self._batched_structure = RaggedStructure.from_ragged(batched_coords)
    model_coords = pl.model_input(batched_coords)
    self._model_structure = RaggedStructure.from_ragged(model_coords)
    self._batched_coords = ragged_layers.values(batched_coords)
    self._model_coords = ragged_layers.values(model_coords)
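# --- Illustrative sketch (not part of the original module) ---
# `RaggedStructure.from_ragged` presumably records the partitioning of a
# batched ragged tensor (row splits / rowids / nrows) separately from its
# flat values. The underlying `tf.RaggedTensor` accessors behave as follows:
import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [4]])
rt.row_splits      # [0, 3, 4]
rt.value_rowids()  # [0, 0, 0, 1] -- batch index of each flat value
rt.values          # [1, 2, 3, 4]
rt.nrows()         # 2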
def feature_inputs(features, features_type="none"):
    if features_type == "none":
        return None
    if features_type == "binary":
        features = tf.split(features, [1, -1], axis=1)[0]
        features = tf.squeeze(features, axis=1) > 0.5
    elif features_type == "float":
        pass
    else:
        raise ValueError(
            '`features_type` must be "none", "binary" or "float",'
            f" got {features_type}"
        )
    out = pl.model_input(pl.batch(pl.cache(features)))
    out = Lambda(lambda x: tf.identity(x.values))(out)
    return out
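# --- Illustrative sketch (not part of the original module) ---
# Worked example of the "binary" branch above: keep the first feature column
# and threshold it at 0.5.
import tensorflow as tf

x = tf.constant([[0.9, 1.0], [0.1, 2.0], [0.7, 3.0]])
first = tf.split(x, [1, -1], axis=1)[0]   # (3, 1)
binary = tf.squeeze(first, axis=1) > 0.5  # [True, False, True]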
def __init__(  # pylint:disable=super-init-not-called
    self, in_cloud: Cloud, indices: IntTensor
):
    assert indices.shape[0] is None
    self._in_cloud = in_cloud
    self._indices = indices
    batched_indices = pl.batch(pl.cache(indices))
    self._batched_structure = RaggedStructure.from_ragged(batched_indices)

    # post-batch
    flat_batched_indices = ragged_layers.values(batched_indices) + tf.gather(
        self.in_cloud.batched_structure.row_starts,
        self.batched_structure.value_rowids,
    )
    model_indices = pl.model_input(
        self._batched_structure.as_ragged(flat_batched_indices)
    )
    self._batched_indices = flat_batched_indices
    self._model_indices = ragged_layers.values(model_indices)
    self._model_structure = RaggedStructure.from_ragged(model_indices)
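# --- Illustrative sketch (not part of the original module) ---
# The offset above turns per-example indices into indices into the flattened
# batch: each value is shifted by the row start of its example.
import tensorflow as tf

batched = tf.ragged.constant([[0, 2], [1]])  # per-example indices
row_starts = tf.constant([0, 3])             # example 1 starts at flat index 3
flat = batched.values + tf.gather(row_starts, batched.value_rowids())
# flat == [0, 2, 4]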
def ip_build(
    coords,
    labels,
    r0=0.1125,
    k0=32,
    num_classes=10,
    weight_fn=hat_weight,
    normalize=True,
):
    # in-place
    del num_classes
    features = None
    labels = pl.batch(pl.cache(labels))
    cloud = comp.Cloud(coords)
    radius = r0
    ip_neigh = cloud.query(radius,
                           polynomial_edge_features,
                           weight_fn,
                           k0,
                           normalize=normalize)
    features = ip_neigh.convolve(features, filters=3, activation="relu")
    return features, labels
def neighborhood(
    in_cloud: Cloud,
    out_cloud: Cloud,
    radius: float,
    indices: tf.RaggedTensor,
    edge_features_fn: TensorMap,
    weight_fn: Optional[TensorMap] = None,
    normalize: bool = True,
    version: str = "v0",
) -> "Neighborhood":
    """
    Get a `Neighborhood` from relevant clouds and parameters.

    Args:
        in_cloud: input cloud.
        out_cloud: output cloud.
        radius: radius of the ball search.
        indices: pre-cache indices into `in_cloud` corresponding to neighbors
            in `in_cloud` of `out_cloud`.
        edge_features_fn: function producing (N, E?) edge features.
        weight_fn: weighting function applied to edge features, a function of
            the normalized edge length. If given, edge features are scaled by
            this value.
        normalize: if True, edge features are divided by the sum over the
            neighborhood of weights given by `weight_fn`.
        version: string influencing what is cached and when the various
            meta-ops are applied. Supported:
                "v0": cache only minimal values and compute relative coords,
                    edge features and weights in the model.
                "v1": cache minimal values, compute relative coords
                    post-batch, edge features / weights in model.
                "v2": cache minimal values, compute relative coords, edge
                    features and weights post-batch.
                "v3": cache relative coordinates, compute edge features /
                    weights in model.
                "v4": cache relative coordinates, edge features and weights.

    Returns:
        `Neighborhood`.
    """

    def get_weights(rel_coords):
        if weight_fn is None:
            return None
        return weight_fn(tf.linalg.norm(rel_coords, axis=-1))

    cached_indices = pl.cache(indices)
    batched_indices = pl.batch(cached_indices)
    # post-batch
    i, j = ragged_to_block_sparse(
        batched_indices, in_cloud.batched_structure.row_starts
    )
    model_i = pl.model_input(i)
    model_j = pl.model_input(j)
    sparse_indices = tf_stack((model_i, model_j), axis=-1)
    if version == "v0":
        # compute rel coords / edge features in model
        rel_coords = tf.gather(in_cloud.model_coords / radius, model_j) - tf.gather(
            out_cloud.model_coords / radius, model_i
        )
        edge_features = edge_features_fn(rel_coords)
        weights = get_weights(rel_coords)
    elif version == "v1":
        # compute rel coords post-batch, the rest in model
        rel_coords = tf.gather(in_cloud.batched_coords / radius, j) - tf.gather(
            out_cloud.batched_coords / radius, i
        )
        rel_coords = pl.model_input(rel_coords)
        # edge features in model
        edge_features = edge_features_fn(rel_coords)
        weights = get_weights(rel_coords)
    elif version == "v2":
        # compute edge features post-batch
        rel_coords = tf.gather(in_cloud.batched_coords / radius, j) - tf.gather(
            out_cloud.batched_coords / radius, i
        )
        edge_features = edge_features_fn(rel_coords)
        weights = get_weights(rel_coords)
        edge_features = pl.model_input(edge_features)
        weights = pl.model_input(weights)
    elif version == "v3":
        # cache relative coords
        i = ragged_layers.value_rowids(indices)
        j = ragged_layers.values(indices)
        rel_coords = tf.gather(in_cloud.coords / radius, j) - tf.gather(
            out_cloud.coords / radius, i
        )
        rel_coords = pl.model_input(pl.batch(pl.cache(rel_coords)))
        rel_coords = ragged_layers.values(rel_coords)
        edge_features = edge_features_fn(rel_coords)
        weights = get_weights(rel_coords)
    elif version == "v4":
        # cache edge features / weights
        i = ragged_layers.value_rowids(indices)
        j = ragged_layers.values(indices)
        rel_coords = tf.gather(in_cloud.coords / radius, j) - tf.gather(
            out_cloud.coords / radius, i
        )
        edge_features = edge_features_fn(rel_coords)  # (N, E?)
        edge_features = tf.transpose(edge_features, (1, 0))  # (E?, N)
        edge_features = pl.batch(pl.cache(edge_features))  # (B, E?, N)
        edge_features = pl.model_input(edge_features)
        edge_features = ragged_layers.values(edge_features)  # (BE, N)
        edge_features = tf.transpose(edge_features, (1, 0))  # (N, BE)
        weights = get_weights(rel_coords)  # (E,)
        weights = pl.model_input(pl.batch(pl.cache(weights)))  # (B, E?)
        weights = ragged_layers.values(weights)  # (BE,)
    else:
        raise ValueError(f"invalid version {version}")
    assert edge_features.shape[0] is not None
    return Neighborhood(
        in_cloud,
        out_cloud,
        sparse_indices,
        edge_features,
        weights,
        normalize=normalize,
    )
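# --- Illustrative sketch (not part of the original module) ---
# `edge_features_fn` (e.g. `polynomial_edge_features`) must map (E?, D)
# relative coordinates to (N, E?) edge features, per the docstring above. A
# minimal order-1 polynomial basis respecting that features-first layout
# might look like this hypothetical version:
import tensorflow as tf

def polynomial_edge_features_sketch(rel_coords: tf.Tensor) -> tf.Tensor:
    # rel_coords: (E, D) normalized relative coordinates.
    # Returns (D + 1, E): a constant monomial plus each coordinate.
    ones = tf.ones_like(rel_coords[..., 0])
    return tf.stack([ones, *tf.unstack(rel_coords, axis=-1)], axis=0)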
def _batch_multi_partition(self) -> Tuple[tf.SparseTensor, ...]:
    num_partitions = self.num_partitions
    assert num_partitions > 1
    components = []
    rowids = tf.ragged.row_splits_to_segment_ids(self._splits,
                                                 out_type=self.dtype)
    partitions = maybe_cast(self._partitions, tf.int32)
    ijs = tf.dynamic_partition(
        tf_stack((rowids, self._indices), axis=-1,
                 name="multi_partition_stack"),
        partitions,
        num_partitions,
    )
    # sorted transpose
    for ij in ijs:
        ij.shape.assert_has_rank(2)
        assert ij.shape[1] == 2
        # num required in tf-nightly (2.5)
        i, j = tf.unstack(ij, num=2, axis=-1)
        indices = ragged_wrappers.from_value_rowids(
            j, i, nrows=maybe_cast(self.out_stream.size, i.dtype))
        components.append(ragged_components(indices))

    all_ragged_indices = [
        pl.batch(
            ragged_wrappers.from_row_splits(
                tf.cast(pl.cache(v), tf.int64), tf.cast(pl.cache(rs),
                                                        tf.int64)))
        for v, rs in components
    ]

    in_stream = self.in_stream
    out_stream = self.out_stream

    # ragged to sparse
    counts = []
    all_b = []
    all_i = []
    all_j = []
    for ragged_indices in all_ragged_indices:
        b = tf.ragged.row_splits_to_segment_ids(
            ragged_wrappers.row_splits(ragged_indices), out_type=self.dtype)
        ragged_indices = ragged_wrappers.values(ragged_indices)
        b = tf.repeat(b, ragged_wrappers.row_lengths(ragged_indices), axis=0)
        # sparse indices must eventually be int64
        i = tf.ragged.row_splits_to_segment_ids(
            ragged_wrappers.row_splits(ragged_indices), out_type=tf.int64)
        j = tf.cast(ragged_wrappers.values(ragged_indices), tf.int64)
        counts.append(ragged_wrappers.row_splits(ragged_indices)[-1])
        # total = tf.split(ragged_indices.row_splits, [-1, 1])[1]
        # counts.append(tf.squeeze(total, axis=0))
        all_b.append(b)
        all_i.append(i)
        all_j.append(j)

    # concat for efficient dt / block diagonalizing.
    cat_b = tf.concat(all_b, axis=0)
    cat_i = tf.concat(all_i, axis=0)
    cat_j = tf.concat(all_j, axis=0)

    # block diagonalize
    # skip i offset since it was automatically done in ragged batching
    cat_j = cat_j + tf.gather(in_stream.batched_structure.row_starts, cat_b)
    cat_dt = (tf.cast(
        tf.gather(out_stream.batched_times, cat_i) -
        tf.gather(in_stream.batched_times, cat_j),
        tf.float32,
    ) / self.decay_time)
    cat_ij = tf_stack((cat_i, cat_j), axis=-1)

    dense_shape = tf_stack(
        (
            maybe_cast(out_stream.batched_structure.total_size, tf.int64),
            maybe_cast(in_stream.batched_structure.total_size, tf.int64),
        ),
        axis=0,
    )
    # tf.SparseTensor indices and dense_shape must be int64
    if dense_shape.dtype != tf.int64:
        dense_shape = tf.cast(dense_shape, tf.int64)

    dts = tf.split(cat_dt, counts)
    ijs = tf.split(cat_ij, counts)
    return tuple(
        sparse_wrappers.SparseTensor(maybe_cast(ij, tf.int64), dt, dense_shape)
        for ij, dt in zip(ijs, dts))
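# --- Illustrative sketch (not part of the original module) ---
# The `tf.dynamic_partition` call above routes stacked (rowid, index) pairs
# to their kernel partitions, e.g.:
import tensorflow as tf

ij = tf.constant([[0, 5], [0, 7], [1, 2]])  # stacked (rowid, index) pairs
partitions = tf.constant([1, 0, 1])         # kernel partition of each pair
parts = tf.dynamic_partition(ij, partitions, num_partitions=2)
# parts[0] == [[0, 7]], parts[1] == [[0, 5], [1, 2]]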
def batched_coords(self):
    coords = pl.cache(self._coords)
    coords = pl.batch(coords)
    return ragged_wrappers.flat_values(coords)
def cached_times(self):
    return pl.cache(self._times)
def inception_vox_pooling(
    features,
    labels,
    sample_weight=None,
    num_classes: int = 10,
    grid_shape: Tuple[int, int] = (128, 128),
    decay_time: int = 2000,
    spatial_buffer: int = 32,
    reset_potential: float = -3.0,
    threshold: float = 1.5,
    filters0: int = 8,
    kt0: int = 4,
    hidden_units: Sequence[int] = (256,),
    dropout_rate: float = 0.5,
    decay_time_expansion_rate: float = 2.0,
    num_levels: int = 5,
    activation: Union[str, Callable] = "relu",
    recenter: bool = True,
    vox_reduction: str = "mean",
    vox_start: int = 2,
    initial_pooling=None,
    max_events=None,
):
    """
    `meta_model.pipeline` build function for event-stream classification.

    Loosely inspired by inception, each block contains 3x3 convolutions, a
    `t` convolution with a cross-shaped spatial kernel mask (e.g.
    [[0, 1, 0], [1, 1, 1], [0, 1, 0]]; the implementation uses a 5x5 cross,
    cf. `t_kernel` below), temporally-deep 1x1 convolutions and feature-deep
    event-wise convolutions. Most output streams are voxelized in x-y-t
    space, forming a separate branch of the network as per the diagram below.

        Stream0 -> Stream1 -> Stream2 -> Stream3
                      |          |          |
                      v          v          v
                    Vox1  --->  Vox2  --->  Vox3
                                             |
                                             v
                                            MLP
                                             |
                                             v
                                           logits

    The above network has num_levels=3, vox_start=1.

    Args:
        features: Dict of KerasTensor with pre-cached "time", "coords",
            "polarity" keys.
        labels: Int KerasTensor with pre-cached class labels.
        sample_weight: Possible KerasTensor with pre-cached example weights.
        num_classes: number of classes for classification problem.
        grid_shape: spatial shape of input grid.
        decay_time: time-scale for initial convolutions.
        spatial_buffer: buffer size used in neighborhood preprocessing. Each
            pixel will store up to this many input events when computing
            neighbors.
        reset_potential: used in leaky integrate and fire for stream
            subsampling.
        threshold: used in leaky integrate and fire for stream subsampling.
        filters0: base number of filters in first block.
        kt0: base temporal kernel size in first block.
        hidden_units: units in hidden layers of final MLP.
        dropout_rate: rate used in Dropout layers.
        decay_time_expansion_rate: factor by which `decay_time` is expanded
            each block.
        num_levels: number of blocks of convolutions.
        activation: activation function / string ID for activations used
            throughout.
        recenter: if True, streams are initially shifted to the image center.
        vox_reduction: string indicating the mechanism by which events are
            accumulated across x-y-t voxels.
        vox_start: the level at which voxelization begins.
        initial_pooling: spatial pooling applied before any learning begins.
        max_events: maximum number of input events before truncation.

    Returns:
        (logits, labels) or (logits, labels, sample_weight) for trained model.

    See also:
    - `meta_models.pipeline.build_pipelined_model`
    - `kblocks.trainables.build_meta_model_trainable`
    """
    if vox_reduction == "max":
        reduction = tf.math.unsorted_segment_max
    else:
        assert vox_reduction == "mean"
        reduction = tf.math.unsorted_segment_mean
    times = features["time"]
    coords = features["coords"]
    polarity = features["polarity"]
    if max_events is not None:
        times = times[:max_events]
        coords = coords[:max_events]
        polarity = polarity[:max_events]

    if initial_pooling is not None:
        if grid_shape is not None:
            grid_shape = tuple(g // initial_pooling for g in grid_shape)
        coords = coords // initial_pooling

    if recenter:
        max_coords = tf.reduce_max(coords, axis=0)
        offset = (tf.constant(grid_shape, dtype=coords.dtype) -
                  max_coords) // 2
        coords = coords + offset

    times = times - times[0]
    t_start = None
    filters = filters0
    activation = tf.keras.activations.get(activation)
    lif_kwargs = dict(reset_potential=reset_potential, threshold=threshold)

    grid = comp.Grid(grid_shape)
    link = grid.link((3, 3), (2, 2), (1, 1))
    in_stream: comp.SpatialStream = comp.SpatialStream(grid, times, coords)
    t_end = in_stream.cached_times[-1] + 1
    t_end = pl.batch(t_end)
    out_stream = comp.spatial_leaky_integrate_and_fire(
        in_stream,
        link,
        decay_time=decay_time,
        **lif_kwargs,
    )

    features = pl.model_input(pl.batch(pl.cache(polarity)))
    batch_size, features = tf.keras.layers.Lambda(
        lambda x: (x.nrows(), tf.identity(x.values)))(features)
    num_frames = 2 ** (num_levels - 1)

    convolver = comp.spatio_temporal_convolver(
        link,
        in_stream,
        out_stream,
        decay_time=decay_time,
        spatial_buffer_size=spatial_buffer,
    )

    features = convolver.convolve(features,
                                  filters=filters,
                                  temporal_kernel_size=kt0,
                                  activation=activation)
    features = layers.BatchNormalization()(features)
    features = Dropout(dropout_rate)(features)

    in_stream = out_stream
    del out_stream
    del convolver
    decay_time = int(decay_time * decay_time_expansion_rate)

    # cross-shaped (plus-sign) 5x5 spatial kernel mask for the `t` convolution
    t_kernel = np.zeros((5, 5), dtype=bool)
    t_kernel[2] = True
    t_kernel[:, 2] = True

    def do_in_place(in_stream: comp.SpatialStream, features, filters):
        link = in_stream.grid.partial_self_link(t_kernel)
        t_convolver = comp.spatio_temporal_convolver(
            link,
            in_stream,
            in_stream,
            decay_time=decay_time,
            spatial_buffer_size=spatial_buffer,
        )
        p_convolver = comp.pointwise_convolver(
            in_stream,
            in_stream,
            spatial_buffer_size=spatial_buffer,
            decay_time=decay_time * 4,
        )
        # (5x1 + 1x5)xt
        ft = t_convolver.convolve(features,
                                  filters=filters,
                                  temporal_kernel_size=kt0)
        # 1x1x4t
        fp = p_convolver.convolve(features,
                                  filters=filters,
                                  temporal_kernel_size=4 * kt0)
        # 1x1x1
        fc = layers.Dense(units=filters * 4, activation=activation)(features)
        fc = layers.Dense(units=filters)(fc)
        branched = activation(ft + fp + fc)
        branched = layers.BatchNormalization()(branched)
        features = features + branched
        return features

    def merge_voxel_features(in_stream: comp.SpatialStream, features,
                             voxel_features, num_frames):
        out_voxel_features = in_stream.voxelize(reduction, features, t_start,
                                                t_end, num_frames, batch_size)
        out_voxel_features = layers.BatchNormalization()(out_voxel_features)
        if voxel_features is None:
            return out_voxel_features

        # residual connection
        voxel_features = layers.Conv3D(features.shape[-1],
                                       2,
                                       2,
                                       activation=activation,
                                       padding="same")(voxel_features)
        voxel_features = layers.BatchNormalization()(voxel_features)
        voxel_features = voxel_features + out_voxel_features
        return voxel_features

    voxel_features = None
    for i in range(num_levels - 1):
        # in place
        features = do_in_place(in_stream, features, filters)
        if i >= vox_start:
            voxel_features = merge_voxel_features(in_stream, features,
                                                  voxel_features, num_frames)
        num_frames //= 2
        filters *= 2

        link = in_stream.grid.link((3, 3), (2, 2), (1, 1))
        out_stream = comp.spatial_leaky_integrate_and_fire(
            in_stream,
            link,
            decay_time=decay_time,
            **lif_kwargs,
        )
        ds_convolver = comp.spatio_temporal_convolver(
            link,
            in_stream,
            out_stream,
            decay_time=decay_time,
            spatial_buffer_size=spatial_buffer,
        )
        features = ds_convolver.convolve(features,
                                         filters=filters,
                                         temporal_kernel_size=kt0,
                                         activation=activation)
        features = layers.BatchNormalization()(features)
        features = Dropout(dropout_rate)(features)
        in_stream = out_stream
        del out_stream
        decay_time = int(decay_time * decay_time_expansion_rate)

    features = do_in_place(in_stream, features, filters)
    voxel_features = merge_voxel_features(in_stream, features, voxel_features,
                                          num_frames)
    assert num_frames == 1
    assert voxel_features.shape[1] == 1
    image_features = Lambda(tf.squeeze, arguments=dict(axis=1))(voxel_features)
    image_features = layers.Dense(2 * filters)(image_features)
    features = tf.keras.layers.GlobalMaxPooling2D()(image_features)

    features = layers.BatchNormalization()(features)
    features = Dropout(dropout_rate)(features)
    for h in hidden_units:
        features = layers.Dense(h, activation=activation)(features)
        features = layers.BatchNormalization()(features)
        features = Dropout(dropout_rate)(features)
    logits = layers.Dense(num_classes, activation=None,
                          name="logits")(features)

    labels = pl.batch(pl.cache(labels))
    if sample_weight is None:
        return logits, labels
    sample_weight = pl.batch(pl.cache(sample_weight))
    return logits, labels, sample_weight
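# --- Illustrative sketch (not part of the original module) ---
# The cross-shaped `t_kernel` mask constructed in `inception_vox_pooling`:
import numpy as np

_k = np.zeros((5, 5), dtype=bool)
_k[2] = True
_k[:, 2] = True
# _k.astype(int):
# [[0 0 1 0 0]
#  [0 0 1 0 0]
#  [1 1 1 1 1]
#  [0 0 1 0 0]
#  [0 0 1 0 0]]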
def _finalize(logits: FloatTensor, labels: IntTensor,
              weights: Optional[FloatTensor]):
    labels = pl.batch(pl.cache(labels))
    if weights is None:
        return logits, labels
    return logits, labels, pl.batch(pl.cache(weights))