Example #1
import numpy as np

# `SQRT_2`, `mlp`, `Dense`, `edge` (deep_cloud.ops.edge) and the extractor
# classes below are project-local names assumed to be in scope.

def get_base_extractor(sizes,
                       flat_node_indices,
                       row_splits,
                       batch_row_splits_or_k,
                       edge_reduction_fn=edge.reduce_max,
                       units_scale=4,
                       unit_expansion_factor=SQRT_2):
    """
    Args:
        sizes: [K] list of int scalars used in edge_reduction_fn, size of all
            elements in the depth across the entire batch.
        flat_node_indices:
        row_splits:
        batch_row_splits_or_k:
        edge_reduction_fn: see `deep_cloud.ops.edge`.
        units_scale: number of dense units is proportional to this.
        unit_expansion_factor: rate at which the number of units increase.

    Returns:
        Sensible `KPartiteFeatureExtractor`.
    """
    local_extractors = []
    global_extractors = []
    depth = len(flat_node_indices)
    for i in range(depth):
        extractors = []
        local_extractors.append(extractors)
        for j in range(i + 1):
            # double units in one sample dimension
            # the other sample dimension increases the receptive field
            # number of ops is constant for a sample rate of 0.25
            units = int(np.round(units_scale * unit_expansion_factor**j))
            extractors.append(
                BiPartiteFeatureExtractor(flat_node_indices[i][j],
                                          row_splits[i][j],
                                          sizes[i],
                                          initial_units=units,
                                          edge_network_fn=mlp([units]),
                                          edge_reduction_fn=edge_reduction_fn,
                                          dense_factory=Dense))
        # global extractors work per-node
        # for sample rate of 0.25, doubling units per layer keeps ops constant
        units = int(np.round(2 * units_scale * unit_expansion_factor**i))
        global_extractors.append(
            GlobalBipartiteFeatureExtractor(batch_row_splits_or_k[i], units,
                                            mlp([units])))

    return KPartiteFeatureExtractor(
        local_extractors,
        global_extractors,
    )
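
With the defaults (`units_scale=4`, `unit_expansion_factor=SQRT_2`), the per-level unit counts grow geometrically, keeping op counts roughly constant at a 0.25 sample rate as the comments above note. A quick check of the rounding, assuming `SQRT_2` is simply sqrt(2):

import numpy as np

SQRT_2 = np.sqrt(2)  # assumed value of the module-level constant
units_scale = 4
# local extractor units at successive levels j = 0..3
print([int(np.round(units_scale * SQRT_2**j)) for j in range(4)])
# [4, 6, 8, 11]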
Example #2
import tensorflow as tf

# `mlp`, `spec`, `pointnet_block` and `get_logits` are project-local helpers.

def pointnet3_classifier(input_spec,
                         output_spec,
                         layer_network_lists=None,
                         global_network=None,
                         logits_network=get_logits,
                         reduction=tf.reduce_max):
    if layer_network_lists is None:
        layer_network_lists = (
            (mlp((32, 32, 64)), mlp((64, 64, 128)), mlp((64, 96, 128))),
            (mlp((64, 64, 128)), mlp((128, 128, 256)), mlp((128, 128, 256))),
        )

    if not (isinstance(layer_network_lists, (list, tuple)) and
            all(isinstance(lns, (list, tuple)) and
                all(callable(ln) for ln in lns)
            for lns in layer_network_lists)):
        raise ValueError(
            'layer_network_lists should be a list/tuple of lists/tuples of '
            'networks, got {}'.format(layer_network_lists))
    if global_network is None:
        global_network = mlp((256, 512, 1024))

    inputs = spec.inputs(input_spec)
    (flat_normals, all_rel_coords, all_node_indices, all_row_splits,
     final_coords, outer_row_splits) = inputs
    if isinstance(final_coords, tuple):
        final_coords = final_coords[-1]

    if len(layer_network_lists) != len(all_rel_coords):
        raise ValueError(
            'Expected same number of layer networks as all_rel_coords '
            'but {} != {}'.format(
                len(layer_network_lists), len(all_rel_coords)))

    node_features = None if flat_normals == () else flat_normals
    for layer_networks, rel_coords, node_indices, row_splits in zip(
            layer_network_lists, all_rel_coords, all_node_indices,
            all_row_splits):
        node_features = [
            pointnet_block(node_features, rc, ni, rs, ln, reduction)
            for (rc, ni, rs, ln) in zip(
                rel_coords, node_indices, row_splits, layer_networks)]
        node_features = tf.concat(node_features, axis=-1)

    global_features = pointnet_block(node_features,
                                     final_coords,
                                     None,
                                     outer_row_splits,
                                     global_network,
                                     reduction=reduction)

    logits = logits_network(global_features, output_spec.shape[-1])
    inputs = tf.nest.flatten(inputs)
    return tf.keras.Model(inputs=inputs, outputs=logits)
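
A hedged usage sketch: `input_spec` and `output_spec` come from the surrounding `spec` module, so everything below is placeholder wiring rather than code from the source. Since the model returns logits, the matching loss uses `from_logits=True`:

# Hypothetical specs; the real ones are produced by the deep_cloud pipeline.
model = pointnet3_classifier(input_spec, output_spec)
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])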
Example #3
def mlp_recurse(units, **kwargs):
    if all(isinstance(u, int) for u in units):
        return mlp(units, **kwargs)
    else:
        return tuple(mlp_recurse(u, **kwargs) for u in units)
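
A quick illustration of the recursion, with a trivial stand-in for `mlp` (the real `mlp` builds a network; this stub, which only records its units, is not from the source):

def mlp(units, **kwargs):  # illustrative stand-in only
    return lambda x: ('mlp', tuple(units), x)

nets = mlp_recurse(((32, 32, 64), (64, 128)))
# `nets` mirrors the nesting of `units`: a tuple of two callables.
single = mlp_recurse((32, 64))
# all-int units -> a single mlp callable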
Example #4
import functools

def get_logits(features, num_classes, dropout_rate=0.5):
    return mlp((512, 256),
               final_units=num_classes,
               dropout_impl=functools.partial(Dropout,
                                              rate=dropout_rate))(features)
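
Assuming `mlp` stacks Dense layers with an activation and applies the given dropout between them (an assumption about its semantics, not taken from the source), the head above corresponds roughly to this plain Keras stack:

import tensorflow as tf

def logits_head(features, num_classes, dropout_rate=0.5):
    x = features
    for units in (512, 256):
        x = tf.keras.layers.Dense(units, activation='relu')(x)
        x = tf.keras.layers.Dropout(dropout_rate)(x)
    # final projection to logits: no activation, matching `final_units`
    return tf.keras.layers.Dense(num_classes)(x)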
Example #5
def get_base_global_network(units=(512, 256), dropout_impl=None):
    return mlp(units=units, dropout_impl=dropout_impl)
Example #6
import tensorflow as tf

# `spec`, `blocks`, `gen_layers`, `mk_layers`, `mlp`, `get_coord_features`,
# `get_batch_norm` and `get_activation` are project-local names.

def generalized_classifier(input_spec,
                           output_spec,
                           coord_features_fn=get_coord_features,
                           dense_factory=mk_layers.Dense,
                           batch_norm_impl=mk_layers.BatchNormalization,
                           activation='relu',
                           global_filters=(512, 256),
                           filters0=32,
                           global_dropout_impl=None):
    batch_norm_fn = get_batch_norm(batch_norm_impl)
    activation_fn = get_activation(activation)

    inputs = spec.inputs(input_spec)
    num_classes = output_spec.shape[-1]

    # class_index = inputs.get('class_index')
    # if class_index is None:
    #     global_features = None
    # else:
    #     global_features = tf.squeeze(tf.keras.layers.Embedding(
    #         num_classes, filters0, input_length=1)(class_index),
    #                                  axis=1)
    (
        all_coords,
        flat_rel_coords,
        flat_node_indices,
        row_splits,
        sample_indices,
        feature_weights,
        # outer_row_splits,
    ) = (
        inputs[k] for k in (
            'all_coords',
            'flat_rel_coords',
            'flat_node_indices',
            'row_splits',
            'sample_indices',
            'feature_weights',
            # 'outer_row_splits',
        ))
    # del outer_row_splits

    depth = len(all_coords)
    coord_features = tuple(coord_features_fn(rc) for rc in flat_rel_coords)
    features = inputs.get('normals')

    filters = filters0
    if features is None:
        features = gen_layers.FeaturelessRaggedConvolution(filters)(
            [flat_rel_coords[0], flat_node_indices[0], feature_weights[0]])
    else:
        raise NotImplementedError()

    activation_kwargs = dict(batch_norm_fn=batch_norm_fn,
                             activation_fn=activation_fn)
    bottleneck_kwargs = dict(dense_factory=dense_factory, **activation_kwargs)

    features = activation_fn(batch_norm_fn(features))
    res_features = []
    for i in range(depth - 1):
        # in place
        features = blocks.in_place_bottleneck(features,
                                              coord_features[2 * i],
                                              flat_node_indices[2 * i],
                                              row_splits[2 * i],
                                              weights=feature_weights[2 * i],
                                              **bottleneck_kwargs)
        res_features.append(features)
        # down sample
        filters *= 2
        features = blocks.down_sample_bottleneck(features,
                                                 coord_features[2 * i + 1],
                                                 flat_node_indices[2 * i + 1],
                                                 row_splits[2 * i + 1],
                                                 feature_weights[2 * i + 1],
                                                 sample_indices[i],
                                                 filters=filters,
                                                 **bottleneck_kwargs)

    features = blocks.in_place_bottleneck(features,
                                          coord_features[-1],
                                          flat_node_indices[-1],
                                          row_splits[-1],
                                          weights=feature_weights[-1],
                                          filters=filters,
                                          **bottleneck_kwargs)

    # global conv
    global_coords = all_coords[-1]
    features = gen_layers.GlobalRaggedConvolution(
        global_filters[0], dense_factory=mk_layers.Dense)(
            [features, global_coords.flat_values, global_coords.row_splits])
    logits = mlp(global_filters[1:],
                 activate_first=True,
                 final_units=num_classes,
                 batch_norm_impl=batch_norm_impl,
                 activation=activation,
                 dropout_impl=global_dropout_impl)(features)
    return tf.keras.Model(tf.nest.flatten(inputs), logits)
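
Note how the channel width doubles at each down-sampling step: starting from `filters0=32`, a hypothetical depth of 4 gives widths 32 -> 64 -> 128 -> 256 before the global convolution. A quick trace of the `filters *= 2` line above:

filters0, depth = 32, 4  # depth is illustrative; it really comes from len(all_coords)
filters = filters0
widths = [filters]
for _ in range(depth - 1):  # one doubling per down-sample block
    filters *= 2
    widths.append(filters)
print(widths)  # [32, 64, 128, 256]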