예제 #1
0
def simple_mlp(x,
               filters_out,
               n_hidden=1,
               filters_hidden=None,
               hidden_activation='relu',
               final_activation='relu'):
    """
    Stack of dense layers forming a simple multi-layer perceptron.

    Applies `n_hidden` dense layers of width `filters_hidden`, each followed
    by `hidden_activation`. It then applies one dense layer of width
    `filters_out`, followed by `final_activation`.

    Args:
        x: [N, filters_in] float32 input features
        filters_out: python int, number of output filters
        n_hidden: python int, number of hidden layers
        filters_hidden: python int, number of filters in each hidden layer.
            Defaults to the size of the last dimension of `x`.
        hidden_activation: activation spec for hidden layers, resolved with
            `_activation`
        final_activation: activation spec for the output layer, resolved with
            `_activation`

    Returns:
        [N, filters_out] float32 output features
    """
    width = x.shape[-1] if filters_hidden is None else filters_hidden
    act_hidden = _activation(hidden_activation)
    act_final = _activation(final_activation)
    out = x
    for _ in range(n_hidden):
        out = act_hidden(core.Dense(width)(out))
    out = core.Dense(filters_out)(out)
    return act_final(out)
예제 #2
0
파일: cls.py 프로젝트: jackd/weighpoint
def cls_tail(
        features, num_classes, hidden_units=(),
        activation=cls_tail_activation):
    """
    Classification tail: activated hidden dense layers followed by logits.

    Args:
        features: input features.
        num_classes: number of output logits.
        hidden_units: sized iterable of ints, the width of each hidden dense
            layer.
        activation: callable applied to `features` before the first dense
            layer and after each hidden dense layer. If None, no activation
            is applied and `hidden_units` must be empty.

    Returns:
        Output of a final `core_layers.Dense(num_classes, activation=None)`.

    Raises:
        ValueError: if `activation` is None but `hidden_units` is non-empty.
    """
    if activation is None:
        # Explicit check rather than `assert`: asserts are stripped under -O,
        # which would leave the loop below calling `None(features)`.
        if len(hidden_units) != 0:
            raise ValueError(
                'hidden_units must be empty when activation is None, got %r'
                % (hidden_units,))
    else:
        features = activation(features)
    for u in hidden_units:
        features = core_layers.Dense(u, activation=None)(features)
        features = activation(features)
    logits = core_layers.Dense(num_classes, activation=None)(features)
    return logits
예제 #3
0
def mlp_layer(
        flat_features, units, activation='relu',
        use_batch_normalization=False,
        dropout_rate=None, init_kernel_scale=1.0):
    """
    Single dense layer with optional dropout and batch normalization.

    Operations are applied in this order: dense (with `activation`), then
    dropout, then batch normalization.

    Args:
        flat_features: [N, units_in] float32 non-ragged float features
        units: number of output features
        activation: activation passed to core.Dense
        use_batch_normalization: if True, apply core.batch_norm after dropout
            and build the dense layer without a bias.
        dropout_rate: rate used in dropout (no dropout if this is None)
        init_kernel_scale: scale passed to utils.scaled_glorot_uniform for the
            dense kernel initializer.

    Returns:
        [N, units] float32 output features
    """
    dense = core.Dense(
        units,
        activation=activation,
        use_bias=not use_batch_normalization,
        kernel_initializer=utils.scaled_glorot_uniform(
            scale=init_kernel_scale))
    out = dense(flat_features)
    if dropout_rate is not None:
        out = core.dropout(out, rate=dropout_rate)
    if use_batch_normalization:
        # The batch-norm scale is disabled only when activation is 'relu'.
        out = core.batch_norm(out, scale=(activation != 'relu'))
    return out
예제 #4
0
    def in_place_conv(self, features, coord_features, batched_neighbors,
                      filters_out, weights):
        """
        Residual wrapper around `self._base.in_place_conv`.

        Builds three chained stages on top of the base convolver:
        two of (base conv -> core.batch_norm(scale=False) -> activation),
        then one of (base conv -> core.batch_norm). The result is combined
        with the input `features` according to `self._combine`:

        - 'add': `features` is added to the result, then `self._activation`
          is applied. If the last dimensions differ, `features` is first
          projected with core.Dense(filters_out) + BatchNormalization.
        - 'concat': `self._activation` is applied to the result, which is
          then concatenated with `features` along the last axis.

        Args:
            features: input features; fed to the first stage and used as the
                shortcut.
            coord_features: forwarded unchanged to the base convolver.
            batched_neighbors: forwarded unchanged to the base convolver.
            filters_out: number of output filters of every base conv and of
                the shortcut projection.
            weights: forwarded unchanged to the base convolver.

        Returns:
            The combined features.

        Raises:
            ValueError: if `self._combine` is neither 'add' nor 'concat'.
        """
        # Validate before building any layers rather than silently
        # returning None after building them.
        if self._combine not in ('add', 'concat'):
            raise ValueError(
                "combine must be 'add' or 'concat', got %r" % (self._combine,))

        def base_conv(f):
            return self._base.in_place_conv(f, coord_features,
                                            batched_neighbors, filters_out,
                                            weights)

        # Each stage consumes the previous stage's output, so the three
        # convolutions are actually chained rather than all reading `features`.
        x = features
        for _ in range(2):
            x = base_conv(x)
            x = core.batch_norm(x, scale=False)
            x = tf.keras.layers.Activation(self._activation)(x)
        x = base_conv(x)
        x = core.batch_norm(x)
        shortcut = features
        if self._combine == 'add':
            if features.shape[-1] != x.shape[-1]:
                # Project the shortcut so its last dimension matches `x`.
                shortcut = core.Dense(filters_out)(shortcut)
                shortcut = tf.keras.layers.BatchNormalization()(shortcut)
            x = tf.keras.layers.Add()([x, shortcut])
            return tf.keras.layers.Activation(self._activation)(x)
        # self._combine == 'concat'
        x = tf.keras.layers.Activation(self._activation)(x)
        return tf.keras.layers.Lambda(
            tf.concat, arguments=dict(axis=-1))([x, shortcut])
예제 #5
0
 def global_conv(self, features, coord_features, row_splits, filters_out):
     """
     Global expanding edge convolution, optionally followed by dense + act.

     Flattens the two leading dims of `coord_features` and calls
     `conv.flat_expanding_edge_conv` with `None` as the neighbor argument.
     (Presumably each `row_splits` segment is treated as a single global
     neighborhood — TODO confirm against `conv`.) Then, if given, it applies
     core.Dense(filters_out) and `self._global_activation`.

     Args:
         features: flat input features.
         coord_features: coordinate features; two leading dims are flattened.
         row_splits: segment boundaries passed to the edge conv.
         filters_out: output filters of the dense layer, or None to skip it.

     Returns:
         The resulting features.
     """
     flat_coords = utils.flatten_leading_dims(coord_features, 2)
     out = conv.flat_expanding_edge_conv(features, flat_coords, None,
                                         row_splits)
     if filters_out is not None:
         out = core.Dense(filters_out)(out)
     activation = self._global_activation
     return out if activation is None else activation(out)
예제 #6
0
    def in_place_conv(self, features, coord_features, batched_neighbors,
                      filters_out, weights):
        """
        Expanding edge convolution over ragged neighborhoods.

        Calls `conv.flat_expanding_edge_conv` on the flat values of the ragged
        inputs, using the innermost row splits of `batched_neighbors`.
        It then optionally applies core.Dense(filters_out) and
        `self._activation`.

        Args:
            features: flat input features.
            coord_features: ragged coordinate features (its `flat_values` is
                used).
            batched_neighbors: ragged neighbor indices (its `flat_values` and
                last `nested_row_splits` entry are used).
            filters_out: output filters of the dense layer, or None to skip it.
            weights: optional ragged edge weights (its `flat_values` is used),
                or None.

        Returns:
            The resulting features.
        """
        flat_coords = coord_features.flat_values
        flat_neighbors = batched_neighbors.flat_values
        neighbor_splits = batched_neighbors.nested_row_splits[-1]
        flat_weights = None if weights is None else weights.flat_values
        out = conv.flat_expanding_edge_conv(
            features, flat_coords, flat_neighbors, neighbor_splits,
            flat_weights)
        if filters_out is not None:
            out = core.Dense(filters_out)(out)
        if self._activation is None:
            return out
        return self._activation(out)
예제 #7
0
파일: lde.py 프로젝트: jackd/weighpoint
def get_log_dense_exp_features(rel_coords,
                               num_complex_features=8,
                               max_order=3,
                               is_total_order=True,
                               dropout_rate=None,
                               use_batch_norm=False):
    """
    Log-dense-exp features of ragged relative coordinates.

    A single `core.Dense(num_complex_features)` layer is shared between two
    paths. The layer's kernel is non-negative and, if `max_order` is set,
    further constrained.

    - magnitude: exp(dense(max(log|inputs|, min_log_value)))
    - angle: dense(_complex_angle(inputs)), mapped by
      `_angle_to_unit_vector`

    The magnitude (expanded on a new last axis) is multiplied by the angle
    components. The final two dims are flattened, and the result is passed
    through `_zeros_if_any_small(inputs, ...)`, optional dropout and optional
    batch norm. All of this is applied to the flat values of `rel_coords`.

    Args:
        rel_coords: ragged relative coordinates.
        num_complex_features: units of the shared dense layer.
        max_order: if not None, bound on the kernel. It is a MaxNorm over
            axis 0 when `is_total_order`, otherwise a `c.MaxValue` bound.
        is_total_order: selects the kind of `max_order` constraint.
        dropout_rate: dropout rate, or None for no dropout.
        use_batch_norm: whether to apply core.batch_norm at the end.

    Returns:
        Ragged tensor with the same ragged structure as `rel_coords`.
    """
    # NOTE: probably going to be gradient problems if `not is_total_order`
    constraints = [tf.keras.constraints.NonNeg()]
    if max_order is not None:
        constraints.append(
            tf.keras.constraints.MaxNorm(max_order, axis=0) if is_total_order
            else c.MaxValue(max_order))
    dense = core.Dense(
        num_complex_features,
        kernel_constraint=c.compound_constraint(*constraints),
        kernel_initializer=tf.initializers.truncated_normal(mean=0.5,
                                                            stddev=0.2))

    def f(inputs, min_log_value=-10):
        lam = layer_utils.lambda_call
        # Magnitude path; the log is clipped from below at min_log_value.
        log_mag = lam(tf.maximum,
                      lam(tf.math.log, lam(tf.abs, inputs)),
                      min_log_value)
        magnitude = lam(tf.exp, dense(log_mag))
        # Angle path, sharing the same dense layer.
        angle = dense(lam(_complex_angle, inputs))
        components = lam(_angle_to_unit_vector, angle)
        out = lam(tf.expand_dims, magnitude, axis=-1)
        out = lam(tf.multiply, out, components)
        out = layer_utils.flatten_final_dims(out, n=2)
        out = lam(_zeros_if_any_small, inputs, out)
        if dropout_rate is not None:
            out = core.dropout(out, dropout_rate)
        if use_batch_norm:
            out = core.batch_norm(out)
        return out

    return tf.ragged.map_flat_values(f, rel_coords)
예제 #8
0
파일: seg.py 프로젝트: jackd/weighpoint
def segmentation_logits(inputs,
                        output_spec,
                        num_obj_classes=16,
                        r0=0.1,
                        initial_filters=(16, ),
                        initial_activation=seg_activation,
                        filters=(32, 64, 128, 256),
                        global_units='combined',
                        query_fn=core.query_pairs,
                        radii_fn=core.constant_radii,
                        global_deconv_all=False,
                        coords_transform=None,
                        weights_transform=None,
                        convolver=None):
    """
    Build per-point segmentation logits with an encoder/decoder point network.

    Args:
        inputs: dict of model inputs. Reads 'positions', 'normals' (required),
            'obj_label' and, optionally, 'valid_classes_mask'.
        output_spec: its `shape[-1]` gives the number of logits per point.
        num_obj_classes: number of object classes passed to
            `_get_class_embeddings`.
        r0: base radius; squared radii are `radii_fn(len(filters)) * r0**2`.
        initial_filters: widths of the per-point dense layers applied to the
            normals before the encoder.
        initial_activation: activation applied after each initial dense layer.
        filters: number of filters at each resolution level.
        global_units: 'combined' to concatenate per-level global
            convolutions; None to skip the global branch; otherwise the
            filters of one global convolution on the coarsest level.
        query_fn: neighborhood query, called as
            `query_fn(coords, radius, name=...)`.
        radii_fn: maps the number of levels to unscaled squared radii
            (a tensor or a numpy-compatible value).
        global_deconv_all: if True, also add global deconvolution features at
            every decoder up-sampling step.
        coords_transform: coordinate feature transform; defaults to
            `t.polynomial_transformer()`.
        weights_transform: neighborhood weights transform; defaults to one
            returning None.
        convolver: provides in_place_conv / resample_conv / global_conv /
            global_deconv; defaults to
            `c.ExpandingConvolver(activation=seg_activation)`.

    Returns:
        Ragged per-point logits. If 'valid_classes_mask' is given, entries
        where the mask is False are replaced by `_neg_inf_like(logits)`.

    Raises:
        NotImplementedError: if `inputs` has no 'normals'.
    """
    if convolver is None:
        convolver = c.ExpandingConvolver(activation=seg_activation)
    if coords_transform is None:
        coords_transform = t.polynomial_transformer()
    if weights_transform is None:

        def weights_transform(*args, **kwargs):
            return None

    coords = inputs['positions']
    normals = inputs.get('normals')

    if normals is None:
        raise NotImplementedError()
    features = b.as_batched_model_input(normals)
    for f in initial_filters:
        features = tf.ragged.map_flat_values(core_layers.Dense(f), features)
        features = tf.ragged.map_flat_values(initial_activation, features)
    assert (isinstance(features, tf.RaggedTensor)
            and features.ragged_rank == 1)

    class_embeddings = _get_class_embeddings(
        b.as_batched_model_input(inputs['obj_label']), num_obj_classes,
        [initial_filters[-1], filters[0]])

    features = core.add_local_global(features, class_embeddings[0])

    input_row_splits = features.row_splits
    features = utils.flatten_leading_dims(features, 2)

    n_res = len(filters)
    unscaled_radii2 = radii_fn(n_res)

    if isinstance(unscaled_radii2, tf.Tensor):
        assert (unscaled_radii2.shape == (n_res, ))
        radii2 = utils.lambda_call(tf.math.scalar_mul, r0**2, unscaled_radii2)
        radii2 = tf.keras.layers.Lambda(tf.unstack,
                                        arguments=dict(axis=0))(radii2)
        for i, radius2 in enumerate(radii2):
            tf.compat.v1.summary.scalar('r%d' % i,
                                        tf.sqrt(radius2),
                                        family='radii')
    else:
        radii2 = unscaled_radii2 * (r0**2)

    def maybe_feed(r2):
        # Radius passed to `query_fn`. This must use `r2`: the enclosing
        # `radius2` still holds the LAST level's value at this point, so using
        # it would give every level the same query radius.
        if isinstance(r2, (tf.Tensor, tf.Variable)):
            return b.prebatch_feed(tf.keras.layers.Lambda(tf.sqrt)(r2))
        return np.sqrt(r2)

    pp_radii2 = [maybe_feed(r2) for r2 in radii2]

    all_features = []
    in_place_neighborhoods = []
    sampled_neighborhoods = []
    global_features = []
    # Squared-radius tensor used at each encoder level, reused by the decoder.
    level_radii2 = []
    # encoder
    for i, (radius2, pp_radius2) in enumerate(zip(radii2, pp_radii2)):
        neighbors, sample_rate = query_fn(coords,
                                          pp_radius2,
                                          name='query%d' % i)
        if not isinstance(radius2, tf.Tensor):
            radius2 = utils.constant(radius2, dtype=tf.float32)
        level_radii2.append(radius2)
        neighborhood = n.InPlaceNeighborhood(coords, neighbors)
        in_place_neighborhoods.append(neighborhood)
        features, nested_row_splits = core.convolve(features, radius2,
                                                    filters[i], neighborhood,
                                                    coords_transform,
                                                    weights_transform,
                                                    convolver.in_place_conv)

        all_features.append(features)

        if global_units == 'combined':
            # Only collect here. The list is concatenated once after the
            # encoder; concatenating inside the loop would replace the list
            # with a tensor and break `append` at the next level.
            coord_features = coords_transform(neighborhood.out_coords, None)
            global_features.append(
                convolver.global_conv(features, coord_features,
                                      nested_row_splits[-2], filters[i]))

        # resample
        if i < n_res - 1:
            sample_indices = sample.sample(
                sample_rate,
                tf.keras.layers.Lambda(lambda s: tf.size(s) // 4)(sample_rate))
            neighborhood = n.SampledNeighborhood(neighborhood, sample_indices)
            sampled_neighborhoods.append(neighborhood)

            features, nested_row_splits = core.convolve(
                features, radius2, filters[i + 1], neighborhood,
                coords_transform, weights_transform, convolver.resample_conv)

            coords = neighborhood.out_coords

    # global_conv
    if global_units is not None:
        row_splits = nested_row_splits[-2]
        if global_units == 'combined':
            global_features = tf.keras.layers.Lambda(
                tf.concat, arguments=dict(axis=-1))(global_features)
        else:
            coord_features = coords_transform(coords, None)
            global_features = convolver.global_conv(features, coord_features,
                                                    row_splits, global_units)

        coord_features = coords_transform(coords, None)
        features = convolver.global_deconv(global_features, coord_features,
                                           row_splits, filters[-1])

    # decoder
    for i in range(n_res - 1, -1, -1):
        # Mirror encoder level i: its sampled and in-place neighborhoods were
        # convolved with this level's radius, not the last level's.
        radius2 = level_radii2[i]
        if i < n_res - 1:
            # up-sample
            neighborhood = sampled_neighborhoods.pop().transpose
            features, nested_row_splits = core.convolve(
                features, radius2, filters[i], neighborhood, coords_transform,
                weights_transform, convolver.resample_conv)
            if global_deconv_all:
                coords = neighborhood.out_coords
                row_splits = \
                    neighborhood.offset_batched_neighbors.nested_row_splits[-2]
                # Same call convention as every other coords_transform call.
                coord_features = coords_transform(coords, None)
                deconv_features = convolver.global_deconv(
                    global_features, coord_features, row_splits, filters[i])
                features = tf.keras.layers.Add()([features, deconv_features])

        forward_features = all_features.pop()
        # At the coarsest level without a global branch, `features` is
        # still the encoder output itself, so there is nothing to concat.
        if not (i == n_res - 1 and global_units is None):
            features = tf.keras.layers.Lambda(tf.concat,
                                              arguments=dict(axis=-1))(
                                                  [features, forward_features])
        neighborhood = in_place_neighborhoods.pop().transpose
        features, nested_row_splits = core.convolve(features, radius2,
                                                    filters[i], neighborhood,
                                                    coords_transform,
                                                    weights_transform,
                                                    convolver.resample_conv)

    features = tf.RaggedTensor.from_row_splits(features, input_row_splits)
    features = core.add_local_global(features, class_embeddings[-1])
    logits = tf.ragged.map_flat_values(
        core_layers.Dense(output_spec.shape[-1]), features)

    valid_classes_mask = inputs.get('valid_classes_mask')
    if valid_classes_mask is not None:
        row_lengths = utils.diff(logits.row_splits)
        valid_classes_mask = b.as_batched_model_input(valid_classes_mask)
        valid_classes_mask = repeat(valid_classes_mask, row_lengths, axis=0)

        def flat_fn(flat_logits):
            neg_inf = tf.keras.layers.Lambda(_neg_inf_like)(flat_logits)
            return utils.lambda_call(tf.where, valid_classes_mask, flat_logits,
                                     neg_inf)

        logits = tf.ragged.map_flat_values(flat_fn, logits)
    return logits
예제 #9
0
파일: cls.py 프로젝트: jackd/weighpoint
def cls_head(
        coords, normals=None, r0=0.1,
        initial_filters=(16,), initial_activation=cls_head_activation,
        filters=(32, 64, 128, 256),
        global_units='combined', query_fn=core.query_pairs,
        radii_fn=core.constant_radii,
        coords_transform=None,
        weights_transform=None,
        convolver=None):
    """
    Build global classification features with a multi-resolution point encoder.

    Args:
        coords: point coordinates used for the neighborhood queries.
        normals: per-point input features, batched via
            `b.as_batched_model_input` (presumably required despite the None
            default — TODO confirm).
        r0: base radius; squared radii are `radii_fn(len(filters)) * r0**2`.
        initial_filters: widths of the per-point dense layers applied to
            `normals` before the encoder.
        initial_activation: activation applied after each initial dense layer.
        filters: number of filters at each resolution level.
        global_units: 'combined' to concatenate (last axis) per-level global
            convolutions. Otherwise one global convolution is run on the
            coarsest level, with `global_units` as its filters_out.
        query_fn: neighborhood query, called as
            `query_fn(coords, radius, name=...)`.
        radii_fn: maps the number of levels to unscaled squared radii
            (a tensor or a numpy-compatible value).
        coords_transform: coordinate feature transform; defaults to
            `t.polynomial_transformer()`.
        weights_transform: neighborhood weights transform; defaults to one
            returning None.
        convolver: provides in_place_conv / resample_conv / global_conv;
            defaults to `c.ExpandingConvolver(activation=cls_head_activation)`.

    Returns:
        Global features as described under `global_units`.
    """
    if convolver is None:
        convolver = c.ExpandingConvolver(activation=cls_head_activation)
    if coords_transform is None:
        coords_transform = t.polynomial_transformer()
    if weights_transform is None:
        def weights_transform(*args, **kwargs):
            return None
    n_res = len(filters)
    unscaled_radii2 = radii_fn(n_res)

    if isinstance(unscaled_radii2, tf.Tensor):
        assert(unscaled_radii2.shape == (n_res,))
        radii2 = utils.lambda_call(tf.math.scalar_mul, r0**2, unscaled_radii2)
        radii2 = tf.keras.layers.Lambda(
            tf.unstack, arguments=dict(axis=0))(radii2)
        for i, radius2 in enumerate(radii2):
            tf.compat.v1.summary.scalar(
                'r%d' % i, tf.sqrt(radius2), family='radii')
    else:
        radii2 = unscaled_radii2 * (r0**2)

    def maybe_feed(r2):
        # Use the argument itself. The enclosing `radius2` only happens to
        # equal `r2` because of where this helper is called below.
        if isinstance(r2, (tf.Tensor, tf.Variable)):
            return b.prebatch_feed(tf.keras.layers.Lambda(tf.sqrt)(r2))
        else:
            return np.sqrt(r2)

    features = b.as_batched_model_input(normals)
    for f in initial_filters:
        layer = core_layers.Dense(f)
        features = tf.ragged.map_flat_values(
            layer, features)
        features = tf.ragged.map_flat_values(initial_activation, features)

    features = utils.flatten_leading_dims(features, 2)
    global_features = []

    for i, radius2 in enumerate(radii2):
        neighbors, sample_rate = query_fn(
            coords, maybe_feed(radius2), name='query%d' % i)
        if not isinstance(radius2, tf.Tensor):
            radius2 = utils.constant(radius2, dtype=tf.float32)
        neighborhood = n.InPlaceNeighborhood(coords, neighbors)
        features, nested_row_splits = core.convolve(
            features, radius2, filters[i], neighborhood, coords_transform,
            weights_transform, convolver.in_place_conv)
        if global_units == 'combined':
            coord_features = coords_transform(neighborhood.out_coords, None)
            global_features.append(convolver.global_conv(
                features, coord_features, nested_row_splits[-2], filters[i]))

        if i < n_res - 1:
            # Keep a quarter of the points for the next level.
            sample_indices = sample.sample(
                sample_rate,
                tf.keras.layers.Lambda(lambda s: tf.size(s) // 4)(sample_rate))
            neighborhood = n.SampledNeighborhood(neighborhood, sample_indices)
            features, nested_row_splits = core.convolve(
                features, radius2, filters[i+1], neighborhood,
                coords_transform, weights_transform, convolver.resample_conv)

            coords = neighborhood.out_coords

    # global_conv
    if global_units == 'combined':
        features = tf.keras.layers.Lambda(tf.concat, arguments=dict(axis=-1))(
            global_features)
    else:
        coord_features = coords_transform(coords, None)
        features = convolver.global_conv(
            features, coord_features, nested_row_splits[-2], global_units)

    return features