def global_deconv(self, global_features, coord_features, row_splits,
                  filters_out):
    """Global deconvolution hook -- not implemented for this convolver.

    This method unconditionally raises, so the `return` statement below is
    unreachable. It is retained only as a sketch of the intended
    implementation via `conv.mlp_global_deconv`.

    Args:
        global_features: global features to deconvolve.
            NOTE(review): presumably one row per cloud -- confirm.
        coord_features: coordinate features; the sketch flattens their two
            leading dims before use.
        row_splits: row splits the sketch forwards to
            `conv.mlp_global_deconv`.
        filters_out: output filter count the sketch forwards to
            `self._global_network_factory`.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError('TODO')
    # Unreachable: intended implementation, kept for reference until the
    # TODO above is resolved.
    return conv.mlp_global_deconv(
        global_features, utils.flatten_leading_dims(coord_features, 2),
        row_splits,
        lambda features, global_features: self._global_network_factory(
            features, global_features, filters_out))
def global_conv(self, features, coord_features, row_splits, filters_out):
    """Apply `conv.mlp_edge_conv` using this object's `_global_fn` network.

    Args:
        features: per-point features passed to the edge conv.
        coord_features: coordinate features; their two leading dims are
            flattened before being handed to the edge conv.
        row_splits: row splits forwarded to `conv.mlp_edge_conv`.
        filters_out: output filter count forwarded to `self._global_fn`.

    Returns:
        Whatever `conv.mlp_edge_conv` returns (called with no weights).
    """
    flat_coord_features = utils.flatten_leading_dims(coord_features, 2)

    def network_fn(edge_features):
        # Bind filters_out so the edge conv only has to supply features.
        return self._global_fn(edge_features, filters_out)

    return conv.mlp_edge_conv(features,
                              flat_coord_features,
                              None,
                              row_splits,
                              network_fn,
                              weights=None)
def _batched_ragged(self, rt):
    """Return a memoized, batched and flattened rebuild of ragged `rt`.

    The outer row count comes from `self._batched_fixed_tensor`. Every nested
    row-length tensor and the flat values pass through `self._batched_tensor`
    and then have their two leading dims flattened. The result is rebuilt
    with `ragged.ragged_from_nested_row_lengths`.

    Results are cached in `self._batched_inputs_dict` keyed by `rt`, so
    repeated calls with the same `rt` return the same object.
    """
    cache = self._batched_inputs_dict
    if rt not in cache:
        # The outer row count is computed first, before touching row lengths.
        nrows = self._batched_fixed_tensor(
            ragged.ragged_lambda(lambda x: x.nrows())(rt))
        # Two passes, so every batched tensor is created before any flatten.
        batched_lengths = [
            self._batched_tensor(lengths)
            for lengths in ragged.nested_row_lengths(rt)
        ]
        flat_lengths = [
            utils.flatten_leading_dims(lengths, 2)
            for lengths in batched_lengths
        ]
        flat_values = utils.flatten_leading_dims(
            self._batched_tensor(rt.flat_values), 2)
        cache[rt] = ragged.ragged_from_nested_row_lengths(
            flat_values, [nrows] + flat_lengths)
    return cache[rt]
def global_conv(self, features, coord_features, row_splits, filters_out):
    """Run `conv.flat_expanding_edge_conv`, then an optional Dense/activation.

    Args:
        features: per-point features passed to the edge conv.
        coord_features: coordinate features; their two leading dims are
            flattened first.
        row_splits: row splits forwarded to the edge conv.
        filters_out: if not None, units of a `core.Dense` layer applied to
            the edge-conv output.

    Returns:
        The edge-conv output, optionally projected by Dense, then passed
        through `self._global_activation` when that attribute is not None.
        The activation is applied even if `filters_out` is None.
    """
    flat_coord_features = utils.flatten_leading_dims(coord_features, 2)
    out = conv.flat_expanding_edge_conv(features, flat_coord_features, None,
                                        row_splits)
    if filters_out is not None:
        out = core.Dense(filters_out)(out)
    activation = self._global_activation
    return out if activation is None else activation(out)
def segmentation_logits(inputs,
                        output_spec,
                        num_obj_classes=16,
                        r0=0.1,
                        initial_filters=(16,),
                        initial_activation=seg_activation,
                        filters=(32, 64, 128, 256),
                        global_units='combined',
                        query_fn=core.query_pairs,
                        radii_fn=core.constant_radii,
                        global_deconv_all=False,
                        coords_transform=None,
                        weights_transform=None,
                        convolver=None):
    """Build per-point segmentation logits with an encoder/decoder network.

    Args:
        inputs: dict of model inputs. Reads 'positions', 'normals',
            'obj_label' and, optionally, 'valid_classes_mask'.
        output_spec: its `shape[-1]` is the number of logits per point.
        num_obj_classes: number of object classes for the class embeddings.
        r0: base radius; squared radii are `radii_fn(len(filters)) * r0**2`.
        initial_filters: units of the per-point Dense layers applied to the
            normals before the first convolution.
        initial_activation: activation applied after each initial Dense.
        filters: filters per resolution level; `len(filters)` is the number
            of levels.
        global_units: 'combined' concatenates per-level global features;
            None skips the global branch; any other value is passed as
            `filters_out` to `convolver.global_conv` on the coarsest level.
        query_fn: neighborhood query, called as
            `query_fn(coords, radius, name=...)` and returning
            `(neighbors, sample_rate)`.
        radii_fn: maps the level count to unscaled squared radii, either a
            tensor of shape `(n_res,)` or a value supporting `* float`
            (presumably a numpy array -- confirm).
        global_deconv_all: if True, global features are also deconvolved and
            added after every decoder up-sample.
        coords_transform: coordinate transform; defaults to
            `t.polynomial_transformer()`.
        weights_transform: weights transform; defaults to one returning None.
        convolver: defaults to `c.ExpandingConvolver(activation=seg_activation)`.

    Returns:
        Ragged logits (one row per input cloud). When 'valid_classes_mask'
        is given, masked-out classes are replaced by `_neg_inf_like` values.

    Raises:
        NotImplementedError: if `inputs` has no 'normals'.
    """
    if convolver is None:
        convolver = c.ExpandingConvolver(activation=seg_activation)
    if coords_transform is None:
        coords_transform = t.polynomial_transformer()
    if weights_transform is None:

        def weights_transform(*args, **kwargs):
            return None

    coords = inputs['positions']
    normals = inputs.get('normals')
    if normals is None:
        raise NotImplementedError()
    features = b.as_batched_model_input(normals)
    for f in initial_filters:
        features = tf.ragged.map_flat_values(core_layers.Dense(f), features)
        features = tf.ragged.map_flat_values(initial_activation, features)
    assert (isinstance(features, tf.RaggedTensor) and
            features.ragged_rank == 1)

    class_embeddings = _get_class_embeddings(
        b.as_batched_model_input(inputs['obj_label']), num_obj_classes,
        [initial_filters[-1], filters[0]])

    features = core.add_local_global(features, class_embeddings[0])
    input_row_splits = features.row_splits
    features = utils.flatten_leading_dims(features, 2)

    n_res = len(filters)
    unscaled_radii2 = radii_fn(n_res)

    if isinstance(unscaled_radii2, tf.Tensor):
        assert (unscaled_radii2.shape == (n_res,))
        radii2 = utils.lambda_call(tf.math.scalar_mul, r0**2, unscaled_radii2)
        radii2 = tf.keras.layers.Lambda(tf.unstack,
                                        arguments=dict(axis=0))(radii2)
        for i, radius2 in enumerate(radii2):
            tf.compat.v1.summary.scalar('r%d' % i,
                                        tf.sqrt(radius2),
                                        family='radii')
    else:
        radii2 = unscaled_radii2 * (r0**2)

    def maybe_feed(r2):
        # Must use the argument `r2`. Reading the enclosing `radius2` here
        # would pick up whatever that name last held (the final radius from
        # the summary loop above) and give every level the same radius.
        if isinstance(r2, (tf.Tensor, tf.Variable)):
            return b.prebatch_feed(tf.keras.layers.Lambda(tf.sqrt)(r2))
        return np.sqrt(r2)

    pp_radii2 = [maybe_feed(r2) for r2 in radii2]

    all_features = []
    in_place_neighborhoods = []
    sampled_neighborhoods = []
    global_features = []
    # Squared radius (as a tensor) per level, so the decoder can convolve
    # each level with the radius its neighborhoods were built with.
    level_radii2 = []

    # encoder
    for i, (radius2, pp_radius2) in enumerate(zip(radii2, pp_radii2)):
        neighbors, sample_rate = query_fn(coords,
                                          pp_radius2,
                                          name='query%d' % i)
        if not isinstance(radius2, tf.Tensor):
            radius2 = utils.constant(radius2, dtype=tf.float32)
        level_radii2.append(radius2)
        neighborhood = n.InPlaceNeighborhood(coords, neighbors)
        in_place_neighborhoods.append(neighborhood)
        features, nested_row_splits = core.convolve(features, radius2,
                                                    filters[i], neighborhood,
                                                    coords_transform,
                                                    weights_transform,
                                                    convolver.in_place_conv)

        all_features.append(features)

        if global_units == 'combined':
            coord_features = coords_transform(neighborhood.out_coords, None)
            # Collected as a list; concatenated once after the loop. An
            # in-loop concat would turn the list into a tensor and break the
            # next iteration's `append`.
            global_features.append(
                convolver.global_conv(features, coord_features,
                                      nested_row_splits[-2], filters[i]))

        # resample
        if i < n_res - 1:
            sample_indices = sample.sample(
                sample_rate,
                tf.keras.layers.Lambda(lambda s: tf.size(s) // 4)(sample_rate))
            neighborhood = n.SampledNeighborhood(neighborhood, sample_indices)
            sampled_neighborhoods.append(neighborhood)

            features, nested_row_splits = core.convolve(
                features, radius2, filters[i + 1], neighborhood,
                coords_transform, weights_transform, convolver.resample_conv)

            coords = neighborhood.out_coords

    # global_conv
    if global_units is not None:
        row_splits = nested_row_splits[-2]
        if global_units == 'combined':
            global_features = tf.keras.layers.Lambda(
                tf.concat, arguments=dict(axis=-1))(global_features)
        else:
            coord_features = coords_transform(coords, None)
            global_features = convolver.global_conv(features, coord_features,
                                                    row_splits, global_units)

        coord_features = coords_transform(coords, None)
        features = convolver.global_deconv(global_features, coord_features,
                                           row_splits, filters[-1])

    # decoder
    for i in range(n_res - 1, -1, -1):
        # Level i's neighborhoods were built with level i's radius.
        radius2 = level_radii2[i]
        if i < n_res - 1:
            # up-sample
            neighborhood = sampled_neighborhoods.pop().transpose
            features, nested_row_splits = core.convolve(
                features, radius2, filters[i], neighborhood, coords_transform,
                weights_transform, convolver.resample_conv)
            if global_deconv_all:
                coords = neighborhood.out_coords
                row_splits = \
                    neighborhood.offset_batched_neighbors.nested_row_splits[-2]
                # Same (coords, None) call convention as every other site.
                coord_features = coords_transform(coords, None)
                deconv_features = convolver.global_deconv(
                    global_features, coord_features, row_splits, filters[i])
                features = tf.keras.layers.Add()([features, deconv_features])

        forward_features = all_features.pop()
        # On the coarsest level without a global branch, `features` already
        # *is* the encoder output, so the skip connection is skipped.
        if not (i == n_res - 1 and global_units is None):
            features = tf.keras.layers.Lambda(tf.concat,
                                              arguments=dict(axis=-1))(
                                                  [features, forward_features])
        neighborhood = in_place_neighborhoods.pop().transpose
        features, nested_row_splits = core.convolve(features, radius2,
                                                    filters[i], neighborhood,
                                                    coords_transform,
                                                    weights_transform,
                                                    convolver.resample_conv)

    features = tf.RaggedTensor.from_row_splits(features, input_row_splits)
    features = core.add_local_global(features, class_embeddings[-1])
    logits = tf.ragged.map_flat_values(
        core_layers.Dense(output_spec.shape[-1]), features)

    valid_classes_mask = inputs.get('valid_classes_mask')
    if valid_classes_mask is not None:
        row_lengths = utils.diff(logits.row_splits)
        valid_classes_mask = b.as_batched_model_input(valid_classes_mask)
        # Expand the per-cloud mask to one row per point.
        valid_classes_mask = repeat(valid_classes_mask, row_lengths, axis=0)

        def flat_fn(flat_logits):
            neg_inf = tf.keras.layers.Lambda(_neg_inf_like)(flat_logits)
            return utils.lambda_call(tf.where, valid_classes_mask,
                                     flat_logits, neg_inf)

        logits = tf.ragged.map_flat_values(flat_fn, logits)
    return logits
def cls_head(coords,
             normals=None,
             r0=0.1,
             initial_filters=(16,),
             initial_activation=cls_head_activation,
             filters=(32, 64, 128, 256),
             global_units='combined',
             query_fn=core.query_pairs,
             radii_fn=core.constant_radii,
             coords_transform=None,
             weights_transform=None,
             convolver=None):
    """Build classification-head global features from a point cloud.

    Args:
        coords: point positions passed to `query_fn`.
        normals: per-point normals used as the initial features.
            NOTE(review): the default of None is passed straight to
            `b.as_batched_model_input`; confirm that is supported.
        r0: base radius; squared radii are `radii_fn(len(filters)) * r0**2`.
        initial_filters: units of the per-point Dense layers applied to the
            normals before the first convolution.
        initial_activation: activation applied after each initial Dense.
        filters: filters per resolution level; `len(filters)` is the number
            of levels.
        global_units: 'combined' returns the concatenation of per-level
            global features; any other value is passed as `filters_out` to
            `convolver.global_conv` on the coarsest level.
        query_fn: neighborhood query, called as
            `query_fn(coords, radius, name=...)` and returning
            `(neighbors, sample_rate)`.
        radii_fn: maps the level count to unscaled squared radii, either a
            tensor of shape `(n_res,)` or a value supporting `* float`.
        coords_transform: defaults to `t.polynomial_transformer()`.
        weights_transform: defaults to a function returning None.
        convolver: defaults to
            `c.ExpandingConvolver(activation=cls_head_activation)`.

    Returns:
        The global features described under `global_units`.
    """
    if convolver is None:
        convolver = c.ExpandingConvolver(activation=cls_head_activation)
    if coords_transform is None:
        coords_transform = t.polynomial_transformer()
    if weights_transform is None:

        def weights_transform(*args, **kwargs):
            return None

    n_res = len(filters)
    unscaled_radii2 = radii_fn(n_res)

    if isinstance(unscaled_radii2, tf.Tensor):
        assert (unscaled_radii2.shape == (n_res,))
        radii2 = utils.lambda_call(tf.math.scalar_mul, r0**2, unscaled_radii2)
        radii2 = tf.keras.layers.Lambda(tf.unstack,
                                        arguments=dict(axis=0))(radii2)
        for i, radius2 in enumerate(radii2):
            tf.compat.v1.summary.scalar('r%d' % i,
                                        tf.sqrt(radius2),
                                        family='radii')
    else:
        radii2 = unscaled_radii2 * (r0**2)

    def maybe_feed(r2):
        # Use the argument, not the enclosing loop variable `radius2`. The
        # latter only worked because every call happened to pass `radius2`.
        if isinstance(r2, (tf.Tensor, tf.Variable)):
            return b.prebatch_feed(tf.keras.layers.Lambda(tf.sqrt)(r2))
        return np.sqrt(r2)

    features = b.as_batched_model_input(normals)
    for f in initial_filters:
        layer = core_layers.Dense(f)
        features = tf.ragged.map_flat_values(layer, features)
        features = tf.ragged.map_flat_values(initial_activation, features)

    features = utils.flatten_leading_dims(features, 2)
    global_features = []

    for i, radius2 in enumerate(radii2):
        neighbors, sample_rate = query_fn(coords,
                                          maybe_feed(radius2),
                                          name='query%d' % i)
        if not isinstance(radius2, tf.Tensor):
            radius2 = utils.constant(radius2, dtype=tf.float32)
        neighborhood = n.InPlaceNeighborhood(coords, neighbors)
        features, nested_row_splits = core.convolve(features, radius2,
                                                    filters[i], neighborhood,
                                                    coords_transform,
                                                    weights_transform,
                                                    convolver.in_place_conv)
        if global_units == 'combined':
            coord_features = coords_transform(neighborhood.out_coords, None)
            global_features.append(
                convolver.global_conv(features, coord_features,
                                      nested_row_splits[-2], filters[i]))

        if i < n_res - 1:
            sample_indices = sample.sample(
                sample_rate,
                tf.keras.layers.Lambda(lambda s: tf.size(s) // 4)(sample_rate))
            neighborhood = n.SampledNeighborhood(neighborhood, sample_indices)
            features, nested_row_splits = core.convolve(
                features, radius2, filters[i + 1], neighborhood,
                coords_transform, weights_transform, convolver.resample_conv)
            coords = neighborhood.out_coords

    # global_conv
    if global_units == 'combined':
        features = tf.keras.layers.Lambda(tf.concat, arguments=dict(axis=-1))(
            global_features)
    else:
        coord_features = coords_transform(coords, None)
        features = convolver.global_conv(features, coord_features,
                                         nested_row_splits[-2], global_units)

    return features