def simple_mlp(x,
               filters_out,
               n_hidden=1,
               filters_hidden=None,
               hidden_activation='relu',
               final_activation='relu'):
    """
    Simple multi-layer perceptron model.

    Args:
        x: [N, filters_in] float32 input features
        filters_out: python int, number of output filters
        n_hidden: python int, number of hidden layers
        filters_hidden: python int, number of filters in each hidden layer
        hidden_activation: activation applied at each hidden layer
        final_activation: activation applied at the end

    Returns:
        [N, filters_out] float32 output features
    """
    if filters_hidden is None:
        filters_hidden = x.shape[-1]
    hidden_activation = _activation(hidden_activation)
    final_activation = _activation(final_activation)
    for _ in range(n_hidden):
        x = core.Dense(filters_hidden)(x)
        x = hidden_activation(x)
    x = core.Dense(filters_out)(x)
    return final_activation(x)
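# A stand-alone sketch of the layer stack simple_mlp builds, assuming
# core.Dense behaves like tf.keras.layers.Dense and that _activation resolves
# activation names (replaced here by tf.keras.layers.Activation). This is an
# illustration in plain Keras, not the repo's own code path.
import tensorflow as tf

mlp = tf.keras.Sequential([
    tf.keras.layers.Dense(8),            # hidden layer, filters_hidden=8
    tf.keras.layers.Activation('relu'),  # hidden_activation
    tf.keras.layers.Dense(4),            # output layer, filters_out=4
    tf.keras.layers.Activation('relu'),  # final_activation
])
y = mlp(tf.random.normal([32, 8]))       # -> [32, 4] float32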
def cls_tail(features,
             num_classes,
             hidden_units=(),
             activation=cls_tail_activation):
    if activation is None:
        assert len(hidden_units) == 0
    else:
        features = activation(features)
    for u in hidden_units:
        features = core_layers.Dense(u, activation=None)(features)
        features = activation(features)
    logits = core_layers.Dense(num_classes, activation=None)(features)
    return logits
def mlp_layer(flat_features,
              units,
              activation='relu',
              use_batch_normalization=False,
              dropout_rate=None,
              init_kernel_scale=1.0):
    """
    Basic multi-layer perceptron layer + call.

    Args:
        flat_features: [N, units_in] non-ragged float32 features
        units: number of output features
        activation: used in core.Dense
        use_batch_normalization: if True, batch normalization is applied
            after dropout
        dropout_rate: rate used in dropout (no dropout if this is None)
        init_kernel_scale: scale used in kernel_initializer

    Returns:
        [N, units] float32 output features
    """
    kernel_initializer = utils.scaled_glorot_uniform(scale=init_kernel_scale)
    flat_features = core.Dense(
        units,
        activation=activation,
        use_bias=not use_batch_normalization,
        kernel_initializer=kernel_initializer)(flat_features)
    if dropout_rate is not None:
        flat_features = core.dropout(flat_features, rate=dropout_rate)
    if use_batch_normalization:
        flat_features = core.batch_norm(flat_features,
                                        scale=(activation != 'relu'))
    return flat_features
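# utils.scaled_glorot_uniform is repo-internal. A plausible stand-in (an
# assumption about its definition, not the repo's code) is glorot uniform
# with an adjustable variance scale, which VarianceScaling expresses
# directly; scale=1.0 reproduces tf.keras.initializers.GlorotUniform.
import tensorflow as tf

def scaled_glorot_uniform(scale=1.0):
    return tf.keras.initializers.VarianceScaling(
        scale=scale, mode='fan_avg', distribution='uniform')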
def in_place_conv(self, features, coord_features, batched_neighbors,
                  filters_out, weights):

    def base_conv(f):
        return self._base.in_place_conv(f, coord_features, batched_neighbors,
                                        filters_out, weights)

    # conv -> batch norm -> activation, chained twice, followed by a final
    # conv -> batch norm before combining with the shortcut.
    x = features
    for _ in range(2):
        x = base_conv(x)
        x = core.batch_norm(x, scale=False)
        x = tf.keras.layers.Activation(self._activation)(x)
    x = base_conv(x)
    x = core.batch_norm(x)

    shortcut = features
    if self._combine == 'add':
        if features.shape[-1] != x.shape[-1]:
            shortcut = core.Dense(filters_out)(shortcut)
            shortcut = tf.keras.layers.BatchNormalization()(shortcut)
        x = tf.keras.layers.Add()([x, shortcut])
        return tf.keras.layers.Activation(self._activation)(x)
    elif self._combine == 'concat':
        x = tf.keras.layers.Activation(self._activation)(x)
        return tf.keras.layers.Lambda(tf.concat,
                                      arguments=dict(axis=-1))([x, shortcut])
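# The 'add' branch above follows the standard projected-shortcut residual
# pattern. A minimal dense-layer analogue in plain Keras (no point-cloud
# convolutions; the names here are illustrative only):
import tensorflow as tf

def residual_dense_block(x, filters_out, activation='relu'):
    y = tf.keras.layers.Dense(filters_out)(x)
    y = tf.keras.layers.BatchNormalization(scale=False)(y)
    y = tf.keras.layers.Activation(activation)(y)
    y = tf.keras.layers.Dense(filters_out)(y)
    y = tf.keras.layers.BatchNormalization()(y)
    shortcut = x
    if x.shape[-1] != filters_out:
        # Project the shortcut when channel counts differ.
        shortcut = tf.keras.layers.Dense(filters_out)(shortcut)
        shortcut = tf.keras.layers.BatchNormalization()(shortcut)
    y = tf.keras.layers.Add()([y, shortcut])
    return tf.keras.layers.Activation(activation)(y)

out = residual_dense_block(tf.random.normal([32, 16]), 32)  # -> [32, 32]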
def global_conv(self, features, coord_features, row_splits, filters_out):
    features = conv.flat_expanding_edge_conv(
        features, utils.flatten_leading_dims(coord_features, 2), None,
        row_splits)
    if filters_out is not None:
        features = core.Dense(filters_out)(features)
    if self._global_activation is not None:
        features = self._global_activation(features)
    return features
def in_place_conv(self, features, coord_features, batched_neighbors,
                  filters_out, weights):
    features = conv.flat_expanding_edge_conv(
        features, coord_features.flat_values, batched_neighbors.flat_values,
        batched_neighbors.nested_row_splits[-1],
        None if weights is None else weights.flat_values)
    if filters_out is not None:
        features = core.Dense(filters_out)(features)
    if self._activation is not None:
        features = self._activation(features)
    return features
def get_log_dense_exp_features(rel_coords,
                               num_complex_features=8,
                               max_order=3,
                               is_total_order=True,
                               dropout_rate=None,
                               use_batch_norm=False):
    # NOTE: probably going to be gradient problems if `not is_total_order`
    constraints = [tf.keras.constraints.NonNeg()]
    if max_order is not None:
        if is_total_order:
            order_constraint = tf.keras.constraints.MaxNorm(max_order, axis=0)
        else:
            order_constraint = c.MaxValue(max_order)
        constraints.append(order_constraint)
    constraint = c.compound_constraint(*constraints)
    dense = core.Dense(num_complex_features,
                       kernel_constraint=constraint,
                       kernel_initializer=tf.initializers.truncated_normal(
                           mean=0.5, stddev=0.2))

    def f(inputs, min_log_value=-10):
        x = inputs
        x = layer_utils.lambda_call(tf.abs, x)
        x = layer_utils.lambda_call(tf.math.log, x)
        x = layer_utils.lambda_call(tf.maximum, x, min_log_value)
        x = dense(x)
        x = layer_utils.lambda_call(tf.exp, x)
        angle = layer_utils.lambda_call(_complex_angle, inputs)
        angle = dense(angle)
        components = layer_utils.lambda_call(_angle_to_unit_vector, angle)
        x = layer_utils.lambda_call(tf.expand_dims, x, axis=-1)
        x = layer_utils.lambda_call(tf.multiply, x, components)
        x = layer_utils.flatten_final_dims(x, n=2)
        x = layer_utils.lambda_call(_zeros_if_any_small, inputs, x)
        if dropout_rate is not None:
            x = core.dropout(x, dropout_rate)
        if use_batch_norm:
            x = core.batch_norm(x)
        return x

    return tf.ragged.map_flat_values(f, rel_coords)
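# Why log -> dense -> exp: with a non-negative kernel W, exp(log|x| @ W)
# equals prod_j |x_j| ** W[j, k] for each output k, i.e. a learned monomial
# of the inputs; the MaxNorm(max_order, axis=0) constraint then bounds the
# norm of each column of exponents (the "total order" above). A tiny eager
# check in plain TF (not repo code):
import tensorflow as tf

x = tf.constant([[2.0, 3.0]])
W = tf.constant([[1.0], [2.0]])  # exponents: x0**1 * x1**2
monomial = tf.exp(tf.matmul(tf.math.log(tf.abs(x)), W))
print(monomial.numpy())          # [[18.]] == 2 * 3**2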
def segmentation_logits(inputs,
                        output_spec,
                        num_obj_classes=16,
                        r0=0.1,
                        initial_filters=(16,),
                        initial_activation=seg_activation,
                        filters=(32, 64, 128, 256),
                        global_units='combined',
                        query_fn=core.query_pairs,
                        radii_fn=core.constant_radii,
                        global_deconv_all=False,
                        coords_transform=None,
                        weights_transform=None,
                        convolver=None):
    if convolver is None:
        convolver = c.ExpandingConvolver(activation=seg_activation)
    if coords_transform is None:
        coords_transform = t.polynomial_transformer()
    if weights_transform is None:

        def weights_transform(*args, **kwargs):
            return None

    coords = inputs['positions']
    normals = inputs.get('normals')
    if normals is None:
        raise NotImplementedError()
    features = b.as_batched_model_input(normals)

    for f in initial_filters:
        features = tf.ragged.map_flat_values(core_layers.Dense(f), features)
        features = tf.ragged.map_flat_values(initial_activation, features)
    assert isinstance(features, tf.RaggedTensor) and features.ragged_rank == 1

    class_embeddings = _get_class_embeddings(
        b.as_batched_model_input(inputs['obj_label']), num_obj_classes,
        [initial_filters[-1], filters[0]])
    features = core.add_local_global(features, class_embeddings[0])

    input_row_splits = features.row_splits
    features = utils.flatten_leading_dims(features, 2)

    n_res = len(filters)
    unscaled_radii2 = radii_fn(n_res)

    if isinstance(unscaled_radii2, tf.Tensor):
        assert unscaled_radii2.shape == (n_res,)
        radii2 = utils.lambda_call(tf.math.scalar_mul, r0**2, unscaled_radii2)
        radii2 = tf.keras.layers.Lambda(tf.unstack,
                                        arguments=dict(axis=0))(radii2)
        for i, radius2 in enumerate(radii2):
            tf.compat.v1.summary.scalar('r%d' % i,
                                        tf.sqrt(radius2),
                                        family='radii')
    else:
        radii2 = unscaled_radii2 * (r0**2)

    def maybe_feed(r2):
        if isinstance(r2, (tf.Tensor, tf.Variable)):
            return b.prebatch_feed(tf.keras.layers.Lambda(tf.sqrt)(r2))
        else:
            return np.sqrt(r2)

    pp_radii2 = [maybe_feed(r2) for r2 in radii2]

    all_features = []
    in_place_neighborhoods = []
    sampled_neighborhoods = []
    global_features = []
    # encoder
    for i, (radius2, pp_radius2) in enumerate(zip(radii2, pp_radii2)):
        neighbors, sample_rate = query_fn(coords,
                                          pp_radius2,
                                          name='query%d' % i)
        if not isinstance(radius2, tf.Tensor):
            radius2 = utils.constant(radius2, dtype=tf.float32)
        neighborhood = n.InPlaceNeighborhood(coords, neighbors)
        in_place_neighborhoods.append(neighborhood)
        features, nested_row_splits = core.convolve(features, radius2,
                                                    filters[i], neighborhood,
                                                    coords_transform,
                                                    weights_transform,
                                                    convolver.in_place_conv)
        all_features.append(features)
        if global_units == 'combined':
            coord_features = coords_transform(neighborhood.out_coords, None)
            global_features.append(
                convolver.global_conv(features, coord_features,
                                      nested_row_splits[-2], filters[i]))

        # resample
        if i < n_res - 1:
            sample_indices = sample.sample(
                sample_rate,
                tf.keras.layers.Lambda(lambda s: tf.size(s) // 4)(sample_rate))
            neighborhood = n.SampledNeighborhood(neighborhood, sample_indices)
            sampled_neighborhoods.append(neighborhood)
            features, nested_row_splits = core.convolve(
                features, radius2, filters[i + 1], neighborhood,
                coords_transform, weights_transform, convolver.resample_conv)
            coords = neighborhood.out_coords

    # global_conv
    if global_units is not None:
        row_splits = nested_row_splits[-2]
        if global_units == 'combined':
            # Concatenate the per-resolution global features collected above.
            global_features = tf.keras.layers.Lambda(
                tf.concat, arguments=dict(axis=-1))(global_features)
        else:
            coord_features = coords_transform(coords, None)
            global_features = convolver.global_conv(features, coord_features,
                                                    row_splits, global_units)
        coord_features = coords_transform(coords, None)
        features = convolver.global_deconv(global_features, coord_features,
                                           row_splits, filters[-1])

    # decoder
    for i in range(n_res - 1, -1, -1):
        if i < n_res - 1:
            # up-sample
            neighborhood = sampled_neighborhoods.pop().transpose
            features, nested_row_splits = core.convolve(
                features, radius2, filters[i], neighborhood, coords_transform,
                weights_transform, convolver.resample_conv)
            if global_deconv_all:
                coords = neighborhood.out_coords
                row_splits = \
                    neighborhood.offset_batched_neighbors.nested_row_splits[-2]
                coord_features = coords_transform(coords, None)
                deconv_features = convolver.global_deconv(
                    global_features, coord_features, row_splits, filters[i])
                features = tf.keras.layers.Add()([features, deconv_features])
        forward_features = all_features.pop()
        if not (i == n_res - 1 and global_units is None):
            features = tf.keras.layers.Lambda(
                tf.concat,
                arguments=dict(axis=-1))([features, forward_features])
        neighborhood = in_place_neighborhoods.pop().transpose
        features, nested_row_splits = core.convolve(features, radius2,
                                                    filters[i], neighborhood,
                                                    coords_transform,
                                                    weights_transform,
                                                    convolver.resample_conv)

    features = tf.RaggedTensor.from_row_splits(features, input_row_splits)
    features = core.add_local_global(features, class_embeddings[-1])
    logits = tf.ragged.map_flat_values(
        core_layers.Dense(output_spec.shape[-1]), features)

    valid_classes_mask = inputs.get('valid_classes_mask')
    if valid_classes_mask is not None:
        row_lengths = utils.diff(logits.row_splits)
        valid_classes_mask = b.as_batched_model_input(valid_classes_mask)
        valid_classes_mask = repeat(valid_classes_mask, row_lengths, axis=0)

        def flat_fn(flat_logits):
            neg_inf = tf.keras.layers.Lambda(_neg_inf_like)(flat_logits)
            return utils.lambda_call(tf.where, valid_classes_mask,
                                     flat_logits, neg_inf)

        logits = tf.ragged.map_flat_values(flat_fn, logits)
    return logits
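# The valid_classes_mask step above maps logits of disallowed classes to
# -inf so that softmax assigns them exactly zero probability. A minimal
# illustration in plain TF (`mask` and the values are hypothetical):
import tensorflow as tf

logits = tf.constant([[1.0, 2.0, 3.0]])
mask = tf.constant([[True, False, True]])
masked = tf.where(mask, logits, tf.fill(tf.shape(logits), float('-inf')))
print(tf.nn.softmax(masked).numpy())  # class 1 gets probability 0.0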
def cls_head(coords,
             normals=None,
             r0=0.1,
             initial_filters=(16,),
             initial_activation=cls_head_activation,
             filters=(32, 64, 128, 256),
             global_units='combined',
             query_fn=core.query_pairs,
             radii_fn=core.constant_radii,
             coords_transform=None,
             weights_transform=None,
             convolver=None):
    if convolver is None:
        convolver = c.ExpandingConvolver(activation=cls_head_activation)
    if coords_transform is None:
        coords_transform = t.polynomial_transformer()
    if weights_transform is None:

        def weights_transform(*args, **kwargs):
            return None

    n_res = len(filters)
    unscaled_radii2 = radii_fn(n_res)

    if isinstance(unscaled_radii2, tf.Tensor):
        assert unscaled_radii2.shape == (n_res,)
        radii2 = utils.lambda_call(tf.math.scalar_mul, r0**2, unscaled_radii2)
        radii2 = tf.keras.layers.Lambda(tf.unstack,
                                        arguments=dict(axis=0))(radii2)
        for i, radius2 in enumerate(radii2):
            tf.compat.v1.summary.scalar('r%d' % i,
                                        tf.sqrt(radius2),
                                        family='radii')
    else:
        radii2 = unscaled_radii2 * (r0**2)

    def maybe_feed(r2):
        if isinstance(r2, (tf.Tensor, tf.Variable)):
            return b.prebatch_feed(tf.keras.layers.Lambda(tf.sqrt)(r2))
        else:
            return np.sqrt(r2)

    features = b.as_batched_model_input(normals)
    for f in initial_filters:
        layer = core_layers.Dense(f)
        features = tf.ragged.map_flat_values(layer, features)
        features = tf.ragged.map_flat_values(initial_activation, features)
    features = utils.flatten_leading_dims(features, 2)

    global_features = []
    for i, radius2 in enumerate(radii2):
        neighbors, sample_rate = query_fn(coords,
                                          maybe_feed(radius2),
                                          name='query%d' % i)
        if not isinstance(radius2, tf.Tensor):
            radius2 = utils.constant(radius2, dtype=tf.float32)
        neighborhood = n.InPlaceNeighborhood(coords, neighbors)
        features, nested_row_splits = core.convolve(features, radius2,
                                                    filters[i], neighborhood,
                                                    coords_transform,
                                                    weights_transform,
                                                    convolver.in_place_conv)
        if global_units == 'combined':
            coord_features = coords_transform(neighborhood.out_coords, None)
            global_features.append(
                convolver.global_conv(features, coord_features,
                                      nested_row_splits[-2], filters[i]))

        if i < n_res - 1:
            sample_indices = sample.sample(
                sample_rate,
                tf.keras.layers.Lambda(lambda s: tf.size(s) // 4)(sample_rate))
            neighborhood = n.SampledNeighborhood(neighborhood, sample_indices)
            features, nested_row_splits = core.convolve(
                features, radius2, filters[i + 1], neighborhood,
                coords_transform, weights_transform, convolver.resample_conv)
            coords = neighborhood.out_coords

    # global_conv
    if global_units == 'combined':
        features = tf.keras.layers.Lambda(
            tf.concat, arguments=dict(axis=-1))(global_features)
    else:
        coord_features = coords_transform(coords, None)
        features = convolver.global_conv(features, coord_features,
                                         nested_row_splits[-2], global_units)
    return features
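# Bookkeeping sketch of the multi-resolution schedule used by cls_head and
# segmentation_logits: level i queries neighbors within radius
# r0 * sqrt(unscaled_radii2[i]) and then keeps roughly a quarter of the
# points (tf.size(sample_rate) // 4). Pure arithmetic, no repo internals;
# the unscaled_radii2 values are hypothetical.
import numpy as np

r0 = 0.1
unscaled_radii2 = np.array([1.0, 4.0, 16.0, 64.0])
radii = np.sqrt(unscaled_radii2 * r0**2)
n_points = 4096
for i, r in enumerate(radii):
    print('level %d: radius %.2f, ~%d points' % (i, r, n_points))
    n_points //= 4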