예제 #1
0
def fixed_budget_radii(num_res):
    """
    Learnable unscaled squared radii constrained by a fixed computational budget.

    The budget assumes point-cloud density quarters at each resolution. It aims
    to keep the sum, across all resolutions, of the sum of all neighborhood
    sizes roughly constant. There are generally more filters at later
    resolutions. The number of floating point operations may therefore grow
    if the model learns to expand the final radii at the expense of the
    earlier ones.

    Args:
        num_res: number of resolutions.

    Returns:
        [num_res] float32 tensor of squared unscaled radii values, initially
        the same as `constant_radii(num_res)`.
    """
    initial_radii2 = constant_radii(num_res)
    # The learnable parameter is the (unsquared) radius. It is initialized so
    # that squaring it reproduces `initial_radii2`.
    learned_radii = utils.variable(
        shape=(num_res,),
        dtype=tf.float32,
        initializer=tf.constant_initializer(np.sqrt(initial_radii2)))
    finite_check = tf.debugging.assert_all_finite(learned_radii, 'radii finite')
    # Outside eager mode, tie the finiteness check to downstream use via a
    # control dependency on an identity op.
    graph_mode = not tf.executing_eagerly()
    if graph_mode and finite_check is not None:
        with tf.compat.v1.control_dependencies([finite_check]):
            learned_radii = tf.keras.layers.Lambda(tf.identity)(learned_radii)
    learned_radii2 = tf.keras.layers.Lambda(tf.square)(learned_radii)
    # Relative density per resolution: 1, 1/4, 1/16, ...
    density = np.power(4.0, -np.arange(num_res)).astype(np.float32)
    budget = utils.lambda_call(_cost, initial_radii2, density)
    return utils.lambda_call(
        _rescale_for_budget, learned_radii2, density, budget)
예제 #2
0
파일: seg.py 프로젝트: jackd/weighpoint
def _get_class_embeddings(class_indices, num_obj_classes, filters):
    """
    Embed object-class indices and split the embedding into per-stage chunks.

    Args:
        class_indices: object class index per example (passed through
            `tf.expand_dims(..., axis=-1)` before the embedding lookup).
        num_obj_classes: number of distinct object classes.
        filters: sequence of chunk sizes; the embedding width is `sum(filters)`.

    Returns:
        Result of `tf.split` along the last axis: one tensor per entry of
        `filters`.
    """
    embedding_layer = tf.keras.layers.Embedding(
        num_obj_classes, output_dim=sum(filters), input_length=1)
    indices_2d = utils.lambda_call(tf.expand_dims, class_indices, axis=-1)
    embedded = embedding_layer(indices_2d)
    # Drop the length-1 sequence axis introduced for the embedding lookup.
    flat_embedding = utils.lambda_call(tf.squeeze, embedded, axis=1)
    return utils.lambda_call(tf.split, flat_embedding, filters, axis=-1)
예제 #3
0
def mlp_global_deconv(
        global_features, coord_features, row_splits_or_k, network_fn):
    """
    Broadcast global features to every point and run them through a network.

    The global features are expanded/tiled per `row_splits_or_k` by
    `_expand_and_tile`. They are then concatenated with `coord_features` along
    the last axis and passed to `network_fn`.
    """
    tiled_global = utils.lambda_call(
        _expand_and_tile, global_features, row_splits_or_k)
    concat_last = tf.keras.layers.Lambda(tf.concat, arguments=dict(axis=-1))
    combined = concat_last([tiled_global, coord_features])
    return network_fn(combined)
예제 #4
0
def query_pairs(
        coords, radius, name=None, max_neighbors=None,
        sample_rate_reciprocal_offset=0):
    """
    Get neighbors and inverse-density sample rates from the provided coords.

    For each point, the sample rate is

        sample_rate = 1. / (num_neighbors - sample_rate_reciprocal_offset)

    where `num_neighbors` is that point's neighborhood size (its row length in
    `neighbors`).

    Args:
        coords: [n, num_dims] float32 point cloud coordinates
        radius: float scalar ball search radius
        name: used in layers
        max_neighbors: if not None, neighborhoods are cropped to this size.
        sample_rate_reciprocal_offset: subtracted from each neighborhood size
            before taking the reciprocal. Larger offsets (e.g. close to 1)
            raise the relative sample rate of points in sparse
            neighborhoods. An offset >= a point's neighborhood size makes the
            denominator non-positive for that point.

    Returns:
        neighbors: [n, ?] int32 ragged tensor of neighboring indices
        sample_rate: sampling rate based roughly on inverse density.
    """
    neighbors = _q.query_pairs(
        coords, radius, name=name, max_neighbors=max_neighbors)
    neighbor_counts = utils.row_lengths(neighbors)
    denominator = utils.cast(neighbor_counts, tf.float32)
    if sample_rate_reciprocal_offset != 0:
        denominator = utils.lambda_call(
            tf.math.subtract, denominator, sample_rate_reciprocal_offset)
    return neighbors, utils.reciprocal(denominator)
예제 #5
0
    def prebatch_feed(self, tensor):
        """
        Denote a learnable tensor as being used in the prebatch mapping.

        This allows learned model parameters (or tensors derived from them) to
        be used in the preprocessing. The resulting mapped datasets need to be
        iterated over by a reinitializable iterator. The values used will only
        be updated after each reinitialization.

        Gradients will not be propagated through the batching/mapping process.

        For example, weighpoint convolutions have a weighting function with an
        associated root that doubles as the ball-search radius. We can learn
        the weighting function in the network as normal and use the resulting
        root to adapt our ball search radius.

        Args:
            tensor: model tensor whose value is fed into the prebatch mapping.

        Returns:
            A tensor with the same shape as `tensor`, marked as PREBATCH. Every
            call returns this squeezed view, including repeated calls with the
            same `tensor`.
        """
        if tensor in self._prebatch_feed_dict:
            # Reuse the existing placeholder so the tensor is only fed once.
            inp = self._prebatch_feed_dict[tensor]
        else:
            self._mark(tensor, Marks.MODEL)
            # batch_size=1 matches `prebatch_input`. The squeeze below
            # requires a leading dimension of size 1 anyway.
            inp = tf.keras.layers.Input(
                shape=tensor.shape, dtype=tensor.dtype, batch_size=1)
            self._prebatch_feed_dict[tensor] = inp
        # The cache stores the batched Input placeholder. The caller always
        # gets the squeezed view, never the raw placeholder (previously a
        # cache hit leaked the extra leading dimension).
        out = utils.lambda_call(tf.squeeze, inp, axis=0)
        self._mark(out, Marks.PREBATCH, recursive=False)
        return out
예제 #6
0
 def prebatch_input(self, shape, dtype, name=None):
     """Create a batch-size-1 Keras input for the prebatch mapping.

     The raw Input is recorded in `self._prebatch_inputs`. The returned tensor
     has the leading batch axis squeezed away and is marked as PREBATCH.
     """
     placeholder = tf.keras.layers.Input(
         shape=shape, dtype=dtype, batch_size=1, name=name)
     self._prebatch_inputs.append(placeholder)
     unbatched = utils.lambda_call(tf.squeeze, placeholder, axis=0)
     self._mark(unbatched, Marks.PREBATCH)
     return unbatched
예제 #7
0
파일: lde.py 프로젝트: jackd/weighpoint
 def f(inputs, min_log_value=-10):
     """Transform `inputs` via separate magnitude and angle branches.

     Magnitude branch: exp(dense(max(log|inputs|, min_log_value))).
     Angle branch: dense(_complex_angle(inputs)) mapped through
     `_angle_to_unit_vector`. The magnitude is multiplied into the unit-vector
     components and the final two dims are flattened. The result is then
     passed through `_zeros_if_any_small(inputs, x)`, and optionally through
     dropout and batch norm.

     NOTE(review): the same `dense` layer (from the enclosing scope) is applied
     to both the log-magnitude and the angle. This is presumably intentional,
     so that the output resembles powers/products of complex values; confirm.

     Args:
         inputs: input features. Presumably interpreted as complex-like values
             by `_complex_angle` — TODO confirm expected shape/dtype.
         min_log_value: lower clamp on log|inputs|, so zero entries give
             min_log_value instead of -inf.

     Returns:
         Transformed features. `dropout_rate` and `use_batch_norm` come from
         the enclosing scope.
     """
     x = inputs
     # Magnitude branch: work in log space so `dense` acts multiplicatively.
     x = layer_utils.lambda_call(tf.abs, x)
     x = layer_utils.lambda_call(tf.math.log, x)
     x = layer_utils.lambda_call(tf.maximum, x, min_log_value)
     x = dense(x)
     x = layer_utils.lambda_call(tf.exp, x)
     # Angle branch shares the same `dense` layer as the magnitude branch.
     angle = layer_utils.lambda_call(_complex_angle, inputs)
     angle = dense(angle)
     components = layer_utils.lambda_call(_angle_to_unit_vector, angle)
     # Scale each unit vector by its magnitude, then merge the last two dims.
     x = layer_utils.lambda_call(tf.expand_dims, x, axis=-1)
     x = layer_utils.lambda_call(tf.multiply, x, components)
     x = layer_utils.flatten_final_dims(x, n=2)
     # NOTE(review): presumably zeroes outputs where inputs are near zero
     # (where log/angle are ill-defined); confirm in `_zeros_if_any_small`.
     x = layer_utils.lambda_call(_zeros_if_any_small, inputs, x)
     if dropout_rate is not None:
         x = core.dropout(x, dropout_rate)
     if use_batch_norm:
         x = core.batch_norm(x)
     return x
예제 #8
0
파일: seg.py 프로젝트: jackd/weighpoint
def segmentation_logits(inputs,
                        output_spec,
                        num_obj_classes=16,
                        r0=0.1,
                        initial_filters=(16, ),
                        initial_activation=seg_activation,
                        filters=(32, 64, 128, 256),
                        global_units='combined',
                        query_fn=core.query_pairs,
                        radii_fn=core.constant_radii,
                        global_deconv_all=False,
                        coords_transform=None,
                        weights_transform=None,
                        convolver=None):
    """
    Build per-point segmentation logits with an encoder/decoder network.

    Normals are embedded by dense layers and combined with an object-class
    embedding. They then pass through `len(filters)` encoder resolutions
    (in-place conv, then down-sampling conv), an optional global stage, and a
    decoder that walks the transposed neighborhoods back up with skip
    connections. A second class embedding is added before the final dense
    layer.

    Args:
        inputs: dict of model inputs. Reads 'positions', 'normals' (required),
            'obj_label', and optionally 'valid_classes_mask'.
        output_spec: its last dimension gives the number of logits per point.
        num_obj_classes: number of object classes for the class embedding.
        r0: radius scale; unscaled squared radii are multiplied by r0**2.
        initial_filters: units of the dense layers applied to the normals.
        initial_activation: activation applied after each initial dense layer.
        filters: number of filters at each resolution.
        global_units: 'combined' concatenates global convs from every
            resolution; None skips the global stage; anything else is passed
            as the units of a single global conv at the coarsest resolution.
        query_fn: neighborhood query, called as
            `query_fn(coords, radius, name=...)` -> (neighbors, sample_rate).
        radii_fn: maps the number of resolutions to unscaled squared radii:
            either a tensor of shape (n_res,) or a value supporting `* float`.
        global_deconv_all: if True, also add global deconv features after
            every decoder up-sampling step.
        coords_transform: coordinate feature transform; defaults to
            `t.polynomial_transformer()`.
        weights_transform: neighborhood weight transform; defaults to one
            returning None.
        convolver: convolution provider; defaults to
            `c.ExpandingConvolver(activation=seg_activation)`.

    Returns:
        Ragged logits split by the row splits of the batched normals. If
        'valid_classes_mask' is given, logits of invalid classes are replaced
        with `_neg_inf_like` values.

    Raises:
        NotImplementedError: if `inputs` has no 'normals'.
    """
    if convolver is None:
        convolver = c.ExpandingConvolver(activation=seg_activation)
    if coords_transform is None:
        coords_transform = t.polynomial_transformer()
    if weights_transform is None:

        def weights_transform(*args, **kwargs):
            return None

    coords = inputs['positions']
    normals = inputs.get('normals')

    if normals is None:
        raise NotImplementedError()
    features = b.as_batched_model_input(normals)
    for f in initial_filters:
        features = tf.ragged.map_flat_values(core_layers.Dense(f), features)
        features = tf.ragged.map_flat_values(initial_activation, features)
    assert (isinstance(features, tf.RaggedTensor)
            and features.ragged_rank == 1)

    # Two class embeddings: one matching the initial features, one matching
    # the finest-resolution decoder output.
    class_embeddings = _get_class_embeddings(
        b.as_batched_model_input(inputs['obj_label']), num_obj_classes,
        [initial_filters[-1], filters[0]])

    features = core.add_local_global(features, class_embeddings[0])

    input_row_splits = features.row_splits
    features = utils.flatten_leading_dims(features, 2)

    n_res = len(filters)
    unscaled_radii2 = radii_fn(n_res)

    if isinstance(unscaled_radii2, tf.Tensor):
        assert (unscaled_radii2.shape == (n_res, ))
        radii2 = utils.lambda_call(tf.math.scalar_mul, r0**2, unscaled_radii2)
        radii2 = tf.keras.layers.Lambda(tf.unstack,
                                        arguments=dict(axis=0))(radii2)
        for i, radius2 in enumerate(radii2):
            tf.compat.v1.summary.scalar('r%d' % i,
                                        tf.sqrt(radius2),
                                        family='radii')
    else:
        radii2 = unscaled_radii2 * (r0**2)

    def maybe_feed(r2):
        # Must use the argument `r2`. The enclosing `radius2` still holds the
        # last value of the summary loop above, so reading it would give
        # every resolution the coarsest radius.
        if isinstance(r2, (tf.Tensor, tf.Variable)):
            return b.prebatch_feed(tf.keras.layers.Lambda(tf.sqrt)(r2))
        return np.sqrt(r2)

    # Ball-search radii (not squared) used by the neighborhood queries.
    pp_radii = [maybe_feed(r2) for r2 in radii2]

    all_features = []
    in_place_neighborhoods = []
    sampled_neighborhoods = []
    global_features = []
    # Squared radius used at each encoder level. The decoder reuses these so
    # each level is convolved with the radius its neighborhoods were built
    # with.
    level_radii2 = []
    # encoder
    for i, (radius2, pp_radius) in enumerate(zip(radii2, pp_radii)):
        neighbors, sample_rate = query_fn(coords,
                                          pp_radius,
                                          name='query%d' % i)
        if not isinstance(radius2, tf.Tensor):
            radius2 = utils.constant(radius2, dtype=tf.float32)
        level_radii2.append(radius2)
        neighborhood = n.InPlaceNeighborhood(coords, neighbors)
        in_place_neighborhoods.append(neighborhood)
        features, nested_row_splits = core.convolve(features, radius2,
                                                    filters[i], neighborhood,
                                                    coords_transform,
                                                    weights_transform,
                                                    convolver.in_place_conv)

        all_features.append(features)

        if global_units == 'combined':
            coord_features = coords_transform(neighborhood.out_coords, None)
            # Collect per-resolution global features. They are concatenated
            # once after the encoder loop; concatenating here would replace
            # the list with a tensor and break the next `append`.
            global_features.append(
                convolver.global_conv(features, coord_features,
                                      nested_row_splits[-2], filters[i]))

        # resample
        if i < n_res - 1:
            sample_indices = sample.sample(
                sample_rate,
                tf.keras.layers.Lambda(lambda s: tf.size(s) // 4)(sample_rate))
            neighborhood = n.SampledNeighborhood(neighborhood, sample_indices)
            sampled_neighborhoods.append(neighborhood)

            features, nested_row_splits = core.convolve(
                features, radius2, filters[i + 1], neighborhood,
                coords_transform, weights_transform, convolver.resample_conv)

            coords = neighborhood.out_coords

    # global_conv
    if global_units is not None:
        row_splits = nested_row_splits[-2]
        if global_units == 'combined':
            global_features = tf.keras.layers.Lambda(
                tf.concat, arguments=dict(axis=-1))(global_features)
        else:
            coord_features = coords_transform(coords, None)
            global_features = convolver.global_conv(features, coord_features,
                                                    row_splits, global_units)

        coord_features = coords_transform(coords, None)
        features = convolver.global_deconv(global_features, coord_features,
                                           row_splits, filters[-1])

    # decoder
    for i in range(n_res - 1, -1, -1):
        radius2 = level_radii2[i]
        if i < n_res - 1:
            # up-sample
            neighborhood = sampled_neighborhoods.pop().transpose
            features, nested_row_splits = core.convolve(
                features, radius2, filters[i], neighborhood, coords_transform,
                weights_transform, convolver.resample_conv)
            if global_deconv_all:
                coords = neighborhood.out_coords
                row_splits = \
                    neighborhood.offset_batched_neighbors.nested_row_splits[-2]
                coord_features = coords_transform(coords, None)
                deconv_features = convolver.global_deconv(
                    global_features, coord_features, row_splits, filters[i])
                features = tf.keras.layers.Add()([features, deconv_features])

        forward_features = all_features.pop()
        # At the coarsest level without a global stage, `features` already
        # is the forward (skip) features, so there is nothing to concatenate.
        if not (i == n_res - 1 and global_units is None):
            features = tf.keras.layers.Lambda(tf.concat,
                                              arguments=dict(axis=-1))(
                                                  [features, forward_features])
        neighborhood = in_place_neighborhoods.pop().transpose
        features, nested_row_splits = core.convolve(features, radius2,
                                                    filters[i], neighborhood,
                                                    coords_transform,
                                                    weights_transform,
                                                    convolver.resample_conv)

    features = tf.RaggedTensor.from_row_splits(features, input_row_splits)
    features = core.add_local_global(features, class_embeddings[-1])
    logits = tf.ragged.map_flat_values(
        core_layers.Dense(output_spec.shape[-1]), features)

    valid_classes_mask = inputs.get('valid_classes_mask')
    if valid_classes_mask is not None:
        row_lengths = utils.diff(logits.row_splits)
        valid_classes_mask = b.as_batched_model_input(valid_classes_mask)
        # Repeat each example's mask once per point so it aligns with the
        # flat logits.
        valid_classes_mask = repeat(valid_classes_mask, row_lengths, axis=0)

        def flat_fn(flat_logits):
            neg_inf = tf.keras.layers.Lambda(_neg_inf_like)(flat_logits)
            return utils.lambda_call(tf.where, valid_classes_mask, flat_logits,
                                     neg_inf)

        logits = tf.ragged.map_flat_values(flat_fn, logits)
    return logits
예제 #9
0
파일: seg.py 프로젝트: jackd/weighpoint
 def flat_fn(flat_logits):
     """Mask logits of classes that are invalid for the current object.

     Where `valid_classes_mask` (from the enclosing scope) is False, the logit
     is replaced by the output of `_neg_inf_like(flat_logits)`, presumably -inf.
     """
     fill_values = tf.keras.layers.Lambda(_neg_inf_like)(flat_logits)
     return utils.lambda_call(
         tf.where, valid_classes_mask, flat_logits, fill_values)
예제 #10
0
def flat_expanding_edge_conv(
        node_features, coord_features, indices, row_splits_or_k, weights=None):
    """Keras-layer wrapper around `conv_ops.flat_expanding_edge_conv`.

    All arguments are forwarded positionally, in order, via
    `utils.lambda_call`.
    """
    return utils.lambda_call(
        conv_ops.flat_expanding_edge_conv,
        node_features,
        coord_features,
        indices,
        row_splits_or_k,
        weights,
    )
예제 #11
0
def reduce_flat_mean(x, row_splits_or_k, weights, eps=1e-7):
    """Keras-layer wrapper around `conv_ops.reduce_flat_mean`.

    `eps` is forwarded as a keyword argument. The other arguments are
    forwarded positionally.
    """
    mean = utils.lambda_call(
        conv_ops.reduce_flat_mean,
        x,
        row_splits_or_k,
        weights,
        eps=eps,
    )
    return mean
예제 #12
0
def flat_expanding_global_deconv(
        global_features, coord_features, row_splits_or_k):
    """Keras-layer wrapper around `conv_ops.flat_expanding_global_deconv`."""
    return utils.lambda_call(
        conv_ops.flat_expanding_global_deconv,
        global_features,
        coord_features,
        row_splits_or_k,
    )
예제 #13
0
파일: cls.py 프로젝트: jackd/weighpoint
def cls_head(
        coords, normals=None, r0=0.1,
        initial_filters=(16,), initial_activation=cls_head_activation,
        filters=(32, 64, 128, 256),
        global_units='combined', query_fn=core.query_pairs,
        radii_fn=core.constant_radii,
        coords_transform=None,
        weights_transform=None,
        convolver=None):
    """
    Build global classification features from a point cloud.

    Normals are embedded by dense layers. The features then pass through
    `len(filters)` resolutions of in-place conv, each followed by a
    down-sampling conv except the last. A global conv stage finishes the head.

    Args:
        coords: point coordinates.
        normals: point normals used as initial features.
        r0: radius scale; unscaled squared radii are multiplied by r0**2.
        initial_filters: units of the dense layers applied to the normals.
        initial_activation: activation applied after each initial dense layer.
        filters: number of filters at each resolution.
        global_units: 'combined' concatenates global convs from every
            resolution. Otherwise it is passed as the units of a single global
            conv at the coarsest resolution.
        query_fn: neighborhood query, called as
            `query_fn(coords, radius, name=...)` -> (neighbors, sample_rate).
        radii_fn: maps the number of resolutions to unscaled squared radii:
            either a tensor of shape (n_res,) or a value supporting `* float`.
        coords_transform: coordinate feature transform; defaults to
            `t.polynomial_transformer()`.
        weights_transform: neighborhood weight transform; defaults to one
            returning None.
        convolver: convolution provider; defaults to
            `c.ExpandingConvolver(activation=cls_head_activation)`.

    Returns:
        Global features: the concatenated per-resolution global convs when
        `global_units == 'combined'`, otherwise the single global conv output.
    """
    if convolver is None:
        convolver = c.ExpandingConvolver(activation=cls_head_activation)
    if coords_transform is None:
        coords_transform = t.polynomial_transformer()
    if weights_transform is None:
        def weights_transform(*args, **kwargs):
            return None
    n_res = len(filters)
    unscaled_radii2 = radii_fn(n_res)

    if isinstance(unscaled_radii2, tf.Tensor):
        assert(unscaled_radii2.shape == (n_res,))
        radii2 = utils.lambda_call(tf.math.scalar_mul, r0**2, unscaled_radii2)
        radii2 = tf.keras.layers.Lambda(
            tf.unstack, arguments=dict(axis=0))(radii2)
        for i, radius2 in enumerate(radii2):
            tf.compat.v1.summary.scalar(
                'r%d' % i, tf.sqrt(radius2), family='radii')
    else:
        radii2 = unscaled_radii2 * (r0**2)

    def maybe_feed(r2):
        # Read the argument `r2`, not the enclosing loop variable `radius2`.
        # The latter only matches `r2` by accident of the call site's order.
        if isinstance(r2, (tf.Tensor, tf.Variable)):
            return b.prebatch_feed(tf.keras.layers.Lambda(tf.sqrt)(r2))
        return np.sqrt(r2)

    features = b.as_batched_model_input(normals)
    for f in initial_filters:
        layer = core_layers.Dense(f)
        features = tf.ragged.map_flat_values(
            layer, features)
        features = tf.ragged.map_flat_values(initial_activation, features)

    features = utils.flatten_leading_dims(features, 2)
    global_features = []

    for i, radius2 in enumerate(radii2):
        neighbors, sample_rate = query_fn(
            coords, maybe_feed(radius2), name='query%d' % i)
        if not isinstance(radius2, tf.Tensor):
            radius2 = utils.constant(radius2, dtype=tf.float32)
        neighborhood = n.InPlaceNeighborhood(coords, neighbors)
        features, nested_row_splits = core.convolve(
            features, radius2, filters[i], neighborhood, coords_transform,
            weights_transform, convolver.in_place_conv)
        if global_units == 'combined':
            coord_features = coords_transform(neighborhood.out_coords, None)
            global_features.append(convolver.global_conv(
                features, coord_features, nested_row_splits[-2], filters[i]))

        if i < n_res - 1:
            # Keep roughly a quarter of the points, sampled by `sample_rate`.
            sample_indices = sample.sample(
                sample_rate,
                tf.keras.layers.Lambda(lambda s: tf.size(s) // 4)(sample_rate))
            neighborhood = n.SampledNeighborhood(neighborhood, sample_indices)
            features, nested_row_splits = core.convolve(
                features, radius2, filters[i+1], neighborhood,
                coords_transform, weights_transform, convolver.resample_conv)

            coords = neighborhood.out_coords

    # global_conv
    if global_units == 'combined':
        features = tf.keras.layers.Lambda(tf.concat, arguments=dict(axis=-1))(
            global_features)
    else:
        coord_features = coords_transform(coords, None)
        features = convolver.global_conv(
            features, coord_features, nested_row_splits[-2], global_units)

    return features