示例#1
0
def set_abstraction(
    features,
    neighborhood,
    edge_mlp,
    reduction=tf.reduce_max,
    node_mlp=None,
    coords_as_features=True,
):
    """
    Reimplementation of original `pointnet_sa_module` function.

    Args:
        features: [b, n_i?, filters_in] float32 tensor of flattened batched
            point features, or None to use only the relative coordinates of
            each neighborhood as edge features.
        neighborhood: `deepcloud.neigh.Neighborhood` instance with `n_i` inputs
            and `n_o` output points.
        edge_mlp: callable acting on each edge's features. It is applied on
            every code path.
        reduction: operation to reduce neighborhoods to point features.
        node_mlp: callable acting on each point after reduction.
        coords_as_features: if True, use relative coords in neighborhood as well
            as features in edges. Ignored when `features` is None (relative
            coords are then the only edge features).

    Returns:
        features: [b, n_o?, filters_o] float32 array, where filters_o is the
          number of output features of `edge_mlp` if `node_mlp` is None else the
          number of output features of `node_mlp`.
    """
    def flat_rel_coords():
        # Flat values of the batched relative coordinates; presumably one row
        # per neighborhood edge, since they are regrouped below using the
        # neighbors' nested row splits.
        return b.as_batched_model_input(
            neighborhood.rel_coords.flat_values).flat_values

    # Every branch below needs this to regroup the flat per-edge features into
    # neighborhoods, so fetch it unconditionally.
    offset_batched_neighbors = neighborhood.offset_batched_neighbors

    if features is None:
        features = edge_mlp(flat_rel_coords())
    else:
        features = layer_utils.flatten_leading_dims(features)
        if coords_as_features:
            # Gather each edge's source-point features, then append the
            # relative coords of that edge before running the edge MLP.
            features = tf.gather(features,
                                 offset_batched_neighbors.flat_values)
            features = layers.Lambda(tf.concat, arguments=dict(axis=-1))(
                [features, flat_rel_coords()])
            features = edge_mlp(features)
        else:
            # More efficient than the original implementation: with no coords
            # to concatenate, the MLP can run once per input point before the
            # gather instead of once per edge (assuming `edge_mlp` acts
            # row-wise).
            features = edge_mlp(features)
            features = tf.gather(features,
                                 offset_batched_neighbors.flat_values)

    # features is now flat: one row per edge across the batch, [E_total, E]
    features = tf.RaggedTensor.from_nested_row_splits(
        features, offset_batched_neighbors.nested_row_splits)
    # features is now [b, n_o?, k?, E]
    features = ragged_lambda(reduction, arguments=dict(axis=-2))(features)
    # features is now [b, n_o?, E]
    if node_mlp is not None:
        if isinstance(features, tf.RaggedTensor):
            features = tf.ragged.map_flat_values(node_mlp, features)
        else:
            features = node_mlp(features)

    return features
示例#2
0
文件: query.py 项目: jackd/deep-cloud
def query_pairs(coords, radius, name=None, max_neighbors=None):
    """Wrap `_query_pairs` in a ragged Lambda layer and apply it to `coords`.

    Args:
        coords: coordinates passed to `_query_pairs`.
        radius: if a `tf.Tensor`, it is fed to the layer as a second input
            alongside `coords`; otherwise it is bound as a constant keyword
            argument (`radius=...`) of `_query_pairs`.
        name: optional name of the created layer (forwarded to
            `ragged_lambda`, as `get_relative_coords` does).
        max_neighbors: forwarded to `_query_pairs` as a keyword argument.

    Returns:
        The output of the layer, as produced by `_query_pairs`.
    """
    if isinstance(radius, tf.Tensor):
        # A tensor radius is passed as a layer input rather than through the
        # layer's constant `arguments`.
        kwargs = {}
        args = coords, radius
    else:
        kwargs = dict(radius=radius)
        args = coords
    kwargs['max_neighbors'] = max_neighbors
    # Previously `name` was accepted but silently dropped.
    return ragged_lambda(_query_pairs, arguments=kwargs, name=name)(args)
示例#3
0
 def _batched_ragged(self, rt):
     """Return a batched version of the ragged input `rt`, memoized per input.

     Results are cached in `self._batched_inputs_dict`, so repeated calls with
     the same `rt` return the same object. Otherwise the row count, every level
     of nested row lengths and the flat values of `rt` are batched with the
     `_batched_*` helpers. The leading two dims of the lengths and values are
     flattened (presumably merging the batch dim into the row dim - TODO
     confirm). The pieces are then reassembled into a ragged tensor whose
     outermost row lengths are the batched row counts.
     """
     if rt in self._batched_inputs_dict:
         return self._batched_inputs_dict[rt]

     # Number of rows of `rt`, batched with the fixed-size helper.
     n_rows = self._batched_fixed_tensor(
         ragged_layers.ragged_lambda(lambda x: x.nrows())(rt))

     # Two separate passes keep the original order of layer construction:
     # first batch every level, then flatten every level.
     batched_lengths = [
         self._batched_tensor(lengths)
         for lengths in ragged_layers.nested_row_lengths(rt)
     ]
     flat_lengths = [
         layer_utils.flatten_leading_dims(lengths, 2)
         for lengths in batched_lengths
     ]

     flat_values = layer_utils.flatten_leading_dims(
         self._batched_tensor(rt.flat_values), 2)
     batched = ragged_layers.ragged_from_nested_row_lengths(
         flat_values, [n_rows, *flat_lengths])

     self._batched_inputs_dict[rt] = batched
     return batched
示例#4
0
文件: cloud.py 项目: jackd/deep-cloud
def get_relative_coords(in_coords, out_coords, indices, name=None):
    """Apply `_get_relative_coords` inside a ragged Lambda layer.

    The three tensors are passed to the layer together as a single list
    input; `name` names the created layer.
    """
    layer = ragged_lambda(_get_relative_coords, name=name)
    inputs = [in_coords, out_coords, indices]
    return layer(inputs)
示例#5
0
文件: utils.py 项目: jackd/more-keras
def row_lengths(row_splits):
    """Return row lengths computed from `row_splits`.

    Dense inputs are handled by `diff` (differences of consecutive splits);
    `tf.RaggedTensor` inputs go through `_row_lengths` wrapped in a ragged
    Lambda layer.
    """
    if not isinstance(row_splits, tf.RaggedTensor):
        return diff(row_splits)
    return ragged_lambda(_row_lengths)(row_splits)
示例#6
0
文件: query.py 项目: jackd/deep-cloud
def reverse_query_pairs(neighbors, size):
    """Apply `_reverse_query_pairs` to `[neighbors, size]` in a ragged Lambda layer."""
    layer = ragged_lambda(_reverse_query_pairs)
    inputs = [neighbors, size]
    return layer(inputs)