Example #1
0
 def offset_batched_neighbors(self):
     """Return the base neighborhood's offset batched neighbors gathered at
     the sampled indices, computing and caching them on first access."""
     cached = self._offset_batched_neighbors
     if cached is None:
         sample_idx = b.as_batched_model_input(self._sample_indices)
         cached = utils.map_gather(self._base.offset_batched_neighbors,
                                   sample_idx)
         self._offset_batched_neighbors = cached
     return cached
Example #2
0
 def offset_batched_neighbors(self):
     """Return batched neighbors with per-row offsets applied, computing and
     caching the result on first access."""
     cached = self._offset_batched_neighbors
     if cached is None:
         batched = b.as_batched_model_input(self.neighbors)
         row_offset = utils.get_row_offsets(batched)
         cached = utils.apply_row_offset(batched, row_offset)
         self._offset_batched_neighbors = cached
     return cached
Example #3
0
 def __call__(self, x, max_radius2):
     """Apply the wrapped fn over the flat values of the batched input,
     memoizing the result per (x, max_radius2) pair."""
     key = (x, max_radius2)
     try:
         return self._cache[key]
     except KeyError:
         batched = b.as_batched_model_input(x)
         value = tf.ragged.map_flat_values(
             lambda flat: self._fn(flat, max_radius2), batched)
         self._cache[key] = value
         return value
Example #4
0
def segmentation_logits(inputs,
                        output_spec,
                        num_obj_classes=16,
                        r0=0.1,
                        initial_filters=(16, ),
                        initial_activation=seg_activation,
                        filters=(32, 64, 128, 256),
                        global_units='combined',
                        query_fn=core.query_pairs,
                        radii_fn=core.constant_radii,
                        global_deconv_all=False,
                        coords_transform=None,
                        weights_transform=None,
                        convolver=None):
    """Build per-point segmentation logits via an encoder/decoder conv stack.

    Args:
        inputs: dict of prebatch inputs with keys 'positions', 'normals',
            'obj_label' and optionally 'valid_classes_mask'.
        output_spec: spec whose trailing dimension gives the number of
            output classes.
        num_obj_classes: number of object categories for class embeddings.
        r0: base radius; squared radii from `radii_fn` are scaled by r0**2.
        initial_filters: dense layer sizes applied to the input features.
        initial_activation: activation applied after each initial dense layer.
        filters: filter counts, one per resolution level.
        global_units: 'combined' (concatenate per-level global convs), an
            int (single global conv with that many units), or None (skip).
        query_fn: fn(coords, radius, name=...) -> (neighbors, sample_rate).
        radii_fn: fn(n_res) -> unscaled squared radii (tf.Tensor or array).
        global_deconv_all: if True, add a global deconv at every decoder
            level rather than only at the bottleneck.
        coords_transform: coordinate feature transform; defaults to a
            polynomial transformer.
        weights_transform: weighting transform; defaults to returning None.
        convolver: convolution implementation; defaults to ExpandingConvolver.

    Returns:
        Ragged per-point logits with output_spec.shape[-1] classes; invalid
        classes (per 'valid_classes_mask') are set to -inf when provided.
    """
    if convolver is None:
        convolver = c.ExpandingConvolver(activation=seg_activation)
    if coords_transform is None:
        coords_transform = t.polynomial_transformer()
    if weights_transform is None:

        def weights_transform(*args, **kwargs):
            return None

    coords = inputs['positions']
    normals = inputs.get('normals')

    if normals is None:
        raise NotImplementedError()
    features = b.as_batched_model_input(normals)
    for f in initial_filters:
        features = tf.ragged.map_flat_values(core_layers.Dense(f), features)
        features = tf.ragged.map_flat_values(initial_activation, features)
    assert (isinstance(features, tf.RaggedTensor)
            and features.ragged_rank == 1)

    class_embeddings = _get_class_embeddings(
        b.as_batched_model_input(inputs['obj_label']), num_obj_classes,
        [initial_filters[-1], filters[0]])

    features = core.add_local_global(features, class_embeddings[0])

    input_row_splits = features.row_splits
    features = utils.flatten_leading_dims(features, 2)

    n_res = len(filters)
    unscaled_radii2 = radii_fn(n_res)

    if isinstance(unscaled_radii2, tf.Tensor):
        assert (unscaled_radii2.shape == (n_res, ))
        radii2 = utils.lambda_call(tf.math.scalar_mul, r0**2, unscaled_radii2)
        radii2 = tf.keras.layers.Lambda(tf.unstack,
                                        arguments=dict(axis=0))(radii2)
        for i, radius2 in enumerate(radii2):
            tf.compat.v1.summary.scalar('r%d' % i,
                                        tf.sqrt(radius2),
                                        family='radii')
    else:
        radii2 = unscaled_radii2 * (r0**2)

    def maybe_feed(r2):
        # Feed the (unsquared) radius at prebatch time for tensor radii;
        # constants are resolved immediately.
        # BUG FIX: previously this used `radius2` -- the variable leaked
        # from the summary loop above -- instead of the parameter `r2`,
        # so every level would have been fed the last level's radius.
        if isinstance(r2, (tf.Tensor, tf.Variable)):
            return b.prebatch_feed(tf.keras.layers.Lambda(tf.sqrt)(r2))
        else:
            return np.sqrt(r2)

    pp_radii2 = [maybe_feed(r2) for r2 in radii2]

    all_features = []
    in_place_neighborhoods = []
    sampled_neighborhoods = []
    global_features = []
    # encoder
    for i, (radius2, pp_radius2) in enumerate(zip(radii2, pp_radii2)):
        neighbors, sample_rate = query_fn(coords,
                                          pp_radius2,
                                          name='query%d' % i)
        if not isinstance(radius2, tf.Tensor):
            radius2 = utils.constant(radius2, dtype=tf.float32)
        neighborhood = n.InPlaceNeighborhood(coords, neighbors)
        in_place_neighborhoods.append(neighborhood)
        features, nested_row_splits = core.convolve(features, radius2,
                                                    filters[i], neighborhood,
                                                    coords_transform,
                                                    weights_transform,
                                                    convolver.in_place_conv)

        all_features.append(features)

        if global_units == 'combined':
            # Collect per-level global features; they are concatenated once
            # after the loop. BUG FIX: the concat previously happened inside
            # the loop, replacing this list with a tensor and breaking the
            # append on the next iteration (cf. cls_head, which concatenates
            # after the loop).
            coord_features = coords_transform(neighborhood.out_coords, None)
            global_features.append(
                convolver.global_conv(features, coord_features,
                                      nested_row_splits[-2], filters[i]))

        # resample to a coarser point set for the next level
        if i < n_res - 1:
            sample_indices = sample.sample(
                sample_rate,
                tf.keras.layers.Lambda(lambda s: tf.size(s) // 4)(sample_rate))
            neighborhood = n.SampledNeighborhood(neighborhood, sample_indices)
            sampled_neighborhoods.append(neighborhood)

            features, nested_row_splits = core.convolve(
                features, radius2, filters[i + 1], neighborhood,
                coords_transform, weights_transform, convolver.resample_conv)

            coords = neighborhood.out_coords

    # global_conv at the bottleneck
    if global_units is not None:
        row_splits = nested_row_splits[-2]
        if global_units == 'combined':
            global_features = tf.keras.layers.Lambda(
                tf.concat, arguments=dict(axis=-1))(global_features)
        else:
            coord_features = coords_transform(coords, None)
            global_features = convolver.global_conv(features, coord_features,
                                                    row_splits, global_units)

        coord_features = coords_transform(coords, None)
        features = convolver.global_deconv(global_features, coord_features,
                                           row_splits, filters[-1])

    # decoder
    for i in range(n_res - 1, -1, -1):
        if i < n_res - 1:
            # up-sample back to the finer resolution
            neighborhood = sampled_neighborhoods.pop().transpose
            features, nested_row_splits = core.convolve(
                features, radius2, filters[i], neighborhood, coords_transform,
                weights_transform, convolver.resample_conv)
            if global_deconv_all:
                coords = neighborhood.out_coords
                row_splits = \
                    neighborhood.offset_batched_neighbors.nested_row_splits[-2]
                # BUG FIX: coords_transform is called as (coords, None)
                # at every other call site; the second argument was missing.
                coord_features = coords_transform(coords, None)
                deconv_features = convolver.global_deconv(
                    global_features, coord_features, row_splits, filters[i])
                features = tf.keras.layers.Add()([features, deconv_features])

        forward_features = all_features.pop()
        if not (i == n_res - 1 and global_units is None):
            # skip connection from the matching encoder level
            features = tf.keras.layers.Lambda(tf.concat,
                                              arguments=dict(axis=-1))(
                                                  [features, forward_features])
        neighborhood = in_place_neighborhoods.pop().transpose
        features, nested_row_splits = core.convolve(features, radius2,
                                                    filters[i], neighborhood,
                                                    coords_transform,
                                                    weights_transform,
                                                    convolver.resample_conv)

    features = tf.RaggedTensor.from_row_splits(features, input_row_splits)
    features = core.add_local_global(features, class_embeddings[-1])
    logits = tf.ragged.map_flat_values(
        core_layers.Dense(output_spec.shape[-1]), features)

    valid_classes_mask = inputs.get('valid_classes_mask')
    if valid_classes_mask is not None:
        # Mask out invalid classes by forcing their logits to -inf.
        row_lengths = utils.diff(logits.row_splits)
        valid_classes_mask = b.as_batched_model_input(valid_classes_mask)
        valid_classes_mask = repeat(valid_classes_mask, row_lengths, axis=0)

        def flat_fn(flat_logits):
            neg_inf = tf.keras.layers.Lambda(_neg_inf_like)(flat_logits)
            return utils.lambda_call(tf.where, valid_classes_mask, flat_logits,
                                     neg_inf)

        logits = tf.ragged.map_flat_values(flat_fn, logits)
    return logits
Example #5
0
 def __call__(self, x, max_radius2):
     """Map the wrapped fn over the flat values of the batched input."""
     batched = b.as_batched_model_input(x)
     return tf.ragged.map_flat_values(
         lambda flat: self._fn(flat, max_radius2), batched)
Example #6
0
    def _test_transposed_consistent(self):
        """Checks that a SampledNeighborhood's transpose equals an explicitly
        constructed Neighborhood with swapped in/out coords and the transposed
        neighbor set, after the prebatch -> batched-model pipeline."""
        batch_size = 3

        with tf.device('/cpu:0'):
            # Three clouds of different sizes to exercise ragged batching.
            np_data = [
                np.random.uniform(size=(100, 3)).astype(np.float32),
                np.random.uniform(size=(200, 3)).astype(np.float32),
                np.random.uniform(size=(50, 3)).astype(np.float32),
            ]

            def generator():

                labels = np.array([0, 1, 2], dtype=np.int64)
                # Per-cloud sample indices of varying lengths.
                indices = [
                    np.array([0, 5, 2, 7, 10], dtype=np.int64),
                    np.array([1, 4, 3, 2], dtype=np.int64),
                    np.array([10, 15], dtype=np.int64),
                ]
                yield (np_data[0], indices[0]), labels[0]
                yield (np_data[1], indices[1]), labels[1]
                yield (np_data[2], indices[2]), labels[2]

            dataset = tf.data.Dataset.from_generator(
                generator,
                output_types=((tf.float32, tf.int64), tf.int64),
                output_shapes=((tf.TensorShape(
                    (None, 3)), tf.TensorShape((None, ))), tf.TensorShape(())))

            # Prebatch placeholders matching the generator's element structure.
            coords = b.prebatch_input((None, 3), tf.float32)
            sample_indices = b.prebatch_input((None, ), dtype=tf.int64)
            neighbors = q.query_pairs(coords, 0.1)
            in_place_neighborhood = n.InPlaceNeighborhood(coords, neighbors)
            sampled_neighborhood = n.SampledNeighborhood(
                in_place_neighborhood, sample_indices)

            # The transposed sampled neighborhood should be equivalent to a
            # Neighborhood built directly from swapped coords + transposed
            # neighbors.
            transposed = sampled_neighborhood.transpose
            simple = n.Neighborhood(sampled_neighborhood.out_coords,
                                    sampled_neighborhood.in_coords,
                                    transposed.neighbors)

            simple_obn = simple.offset_batched_neighbors
            trans_obn = transposed.offset_batched_neighbors
            simple_out = b.as_batched_model_input(simple.in_coords)
            trans_out = b.as_batched_model_input(transposed.in_coords)

            # Flatten ragged outputs into (flat_values, nested_row_splits)
            # pairs so they can pass through the keras model boundary.
            out = [[o.flat_values, o.nested_row_splits]
                   for o in (simple_obn, trans_obn, simple_out, trans_out)]
            flat_out = tf.nest.flatten(out)

            model = b.model(flat_out)
            preprocessor = b.preprocessor()

            dataset = preprocessor.map_and_batch(dataset, batch_size)
            # if tf.executing_eagerly():
            #     for data, label in dataset.take(1):
            #         pass
            # else:
            data, label = tf.compat.v1.data.make_one_shot_iterator(
                dataset).get_next()

            flat_out = model(tf.nest.flatten(data))
            if not isinstance(flat_out, (list, tuple)):
                flat_out = flat_out,
            out = tf.nest.pack_sequence_as(out, flat_out)

            # Rebuild the ragged tensors from the flattened components.
            out = [tf.RaggedTensor.from_nested_row_splits(*o) for o in out]
            out, label = self.evaluate((out, label))
            simple_obn, trans_obn, simple_coords, trans_coords = out
            self.assertRaggedEqual(simple_obn, trans_obn)
            self.assertRaggedEqual(simple_coords, trans_coords)
            # Innermost splits should cover every point across the batch.
            self.assertEqual(simple_obn.nested_row_splits[-1][-1],
                             sum(d.shape[0] for d in np_data))
Example #7
0
def cls_head(
        coords, normals=None, r0=0.1,
        initial_filters=(16,), initial_activation=cls_head_activation,
        filters=(32, 64, 128, 256),
        global_units='combined', query_fn=core.query_pairs,
        radii_fn=core.constant_radii,
        coords_transform=None,
        weights_transform=None,
        convolver=None):
    """Build global classification features from a point cloud.

    Runs an encoder of in-place + resample convolutions over progressively
    coarser point sets, then reduces to per-cloud features via a global conv.

    Args:
        coords: prebatch point coordinates.
        normals: prebatch point normals used as the initial features.
        r0: base radius; squared radii from `radii_fn` are scaled by r0**2.
        initial_filters: dense layer sizes applied to the input features.
        initial_activation: activation applied after each initial dense layer.
        filters: filter counts, one per resolution level.
        global_units: 'combined' (concatenate per-level global convs) or an
            int (single global conv with that many units).
        query_fn: fn(coords, radius, name=...) -> (neighbors, sample_rate).
        radii_fn: fn(n_res) -> unscaled squared radii (tf.Tensor or array).
        coords_transform: coordinate feature transform; defaults to a
            polynomial transformer.
        weights_transform: weighting transform; defaults to returning None.
        convolver: convolution implementation; defaults to ExpandingConvolver.

    Returns:
        Per-cloud feature tensor.
    """
    if convolver is None:
        convolver = c.ExpandingConvolver(activation=cls_head_activation)
    if coords_transform is None:
        coords_transform = t.polynomial_transformer()
    if weights_transform is None:
        def weights_transform(*args, **kwargs):
            return None
    n_res = len(filters)
    unscaled_radii2 = radii_fn(n_res)

    if isinstance(unscaled_radii2, tf.Tensor):
        assert(unscaled_radii2.shape == (n_res,))
        radii2 = utils.lambda_call(tf.math.scalar_mul, r0**2, unscaled_radii2)
        radii2 = tf.keras.layers.Lambda(
            tf.unstack, arguments=dict(axis=0))(radii2)
        for i, radius2 in enumerate(radii2):
            tf.compat.v1.summary.scalar(
                'r%d' % i, tf.sqrt(radius2), family='radii')
    else:
        radii2 = unscaled_radii2 * (r0**2)

    def maybe_feed(r2):
        # Feed the (unsquared) radius at prebatch time for tensor radii;
        # constants are resolved immediately.
        # BUG FIX: previously this used the enclosing-scope `radius2`
        # (leaked from the summary loop / shadowed by the encoder loop)
        # instead of the parameter `r2`.
        if isinstance(r2, (tf.Tensor, tf.Variable)):
            return b.prebatch_feed(tf.keras.layers.Lambda(tf.sqrt)(r2))
        else:
            return np.sqrt(r2)

    features = b.as_batched_model_input(normals)
    for f in initial_filters:
        layer = core_layers.Dense(f)
        features = tf.ragged.map_flat_values(
            layer, features)
        features = tf.ragged.map_flat_values(initial_activation, features)

    features = utils.flatten_leading_dims(features, 2)
    global_features = []

    # encoder: in-place conv per level, then resample to a coarser set
    for i, radius2 in enumerate(radii2):
        neighbors, sample_rate = query_fn(
            coords, maybe_feed(radius2), name='query%d' % i)
        if not isinstance(radius2, tf.Tensor):
            radius2 = utils.constant(radius2, dtype=tf.float32)
        neighborhood = n.InPlaceNeighborhood(coords, neighbors)
        features, nested_row_splits = core.convolve(
            features, radius2, filters[i], neighborhood, coords_transform,
            weights_transform, convolver.in_place_conv)
        if global_units == 'combined':
            # Collect per-level global features; concatenated after the loop.
            coord_features = coords_transform(neighborhood.out_coords, None)
            global_features.append(convolver.global_conv(
                features, coord_features, nested_row_splits[-2], filters[i]))

        if i < n_res - 1:
            # sample roughly a quarter of the points for the next level
            sample_indices = sample.sample(
                sample_rate,
                tf.keras.layers.Lambda(lambda s: tf.size(s) // 4)(sample_rate))
            neighborhood = n.SampledNeighborhood(neighborhood, sample_indices)
            features, nested_row_splits = core.convolve(
                features, radius2, filters[i+1], neighborhood,
                coords_transform, weights_transform, convolver.resample_conv)

            coords = neighborhood.out_coords

    # global_conv
    if global_units == 'combined':
        features = tf.keras.layers.Lambda(tf.concat, arguments=dict(axis=-1))(
            global_features)
    else:
        coord_features = coords_transform(coords, None)
        features = convolver.global_conv(
            features, coord_features, nested_row_splits[-2], global_units)

    return features