Ejemplo n.º 1
0
def resnet_model(input_spec, output_spec, logits_fn=semantic_segmenter_logits):
    """Build a Keras model that maps `input_spec` inputs to logits.

    Args:
        input_spec: nested input spec passed to `more_keras.spec.inputs`.
        output_spec: output spec; its last dimension is forwarded to
            `logits_fn` as the number of output channels.
        logits_fn: callable `(inputs, num_channels) -> logits`.

    Returns:
        `tf.keras.models.Model` whose inputs are the flattened created inputs.
    """
    from more_keras import spec

    model_inputs = spec.inputs(input_spec)
    num_channels = output_spec.shape[-1]
    # Build logits before flattening, in case `logits_fn` alters the inputs.
    logits = logits_fn(model_inputs, num_channels)
    return tf.keras.models.Model(tf.nest.flatten(model_inputs), outputs=logits)
Ejemplo n.º 2
0
def pointnet3_classifier(input_spec,
                         output_spec,
                         layer_network_lists=None,
                         global_network=None,
                         logits_network=get_logits,
                         reduction=tf.reduce_max):
    """Build a multi-level PointNet-style point cloud classifier.

    Args:
        input_spec: nested input spec passed to `spec.inputs`. The created
            inputs are unpacked as `(flat_normals, all_rel_coords,
            all_node_indices, all_row_splits, final_coords,
            outer_row_splits)`.
        output_spec: output spec; `output_spec.shape[-1]` is the number of
            classes.
        layer_network_lists: list/tuple with one entry per level, each entry
            a list/tuple of callable networks (one per neighborhood at that
            level). Defaults to two levels of three `mlp` networks each.
        global_network: network used by the final global `pointnet_block`.
            Defaults to `mlp((256, 512, 1024))`.
        logits_network: callable `(global_features, num_classes) -> logits`
            applied to the global features.
        reduction: reduction passed to every `pointnet_block`.

    Returns:
        `tf.keras.Model` mapping the flattened inputs to logits.

    Raises:
        ValueError: if `layer_network_lists` is not a list/tuple of
            list/tuples of callables, or if its length differs from
            `len(all_rel_coords)`.
    """
    if layer_network_lists is None:
        layer_network_lists = (
            (mlp((32, 32, 64)), mlp((64, 64, 128)), mlp((64, 96, 128))),
            (mlp((64, 64, 128)), mlp((128, 128, 256)), mlp((128, 128, 256))),
        )

    if not (isinstance(layer_network_lists, (list, tuple)) and
            all(isinstance(lns, (list, tuple)) and
                all(callable(ln) for ln in lns)
            for lns in layer_network_lists)):
        raise ValueError(
            'layer_networks should be a list/tuple of list/tuples of networks'
            ' got {}'.format(layer_network_lists))
    if global_network is None:
        global_network = mlp((256, 512, 1024))

    inputs = spec.inputs(input_spec)
    (flat_normals, all_rel_coords, all_node_indices, all_row_splits,
     final_coords, outer_row_splits) = inputs
    if isinstance(final_coords, tuple):
        # Only the last entry is used by the global block.
        final_coords = final_coords[-1]

    if len(layer_network_lists) != len(all_rel_coords):
        raise ValueError(
            'Expected same number of layer networks as all_rel_coords '
            'but {} != {}'.format(
                len(layer_network_lists), len(all_rel_coords)))

    # An empty tuple means "no normals". Test the type explicitly: `== ()`
    # would invoke tensor equality (and then tensor truthiness) when
    # `flat_normals` is a tensor.
    no_normals = isinstance(flat_normals, tuple) and not flat_normals
    node_features = None if no_normals else flat_normals
    for layer_networks, rel_coords, node_indices, row_splits in zip(
            layer_network_lists, all_rel_coords, all_node_indices,
            all_row_splits):
        # One pointnet block per neighborhood, concatenated on the last axis.
        node_features = [
            pointnet_block(node_features, rc, ni, rs, ln, reduction)
            for (rc, ni, rs, ln) in zip(
                rel_coords, node_indices, row_splits, layer_networks)]
        node_features = tf.concat(node_features, axis=-1)

    global_features = pointnet_block(node_features,
                                     final_coords,
                                     None,
                                     outer_row_splits,
                                     global_network,
                                     reduction=reduction)

    # Honor the `logits_network` argument (previously ignored in favour of a
    # hard-coded `get_logits`, which is still the default).
    logits = logits_network(global_features, output_spec.shape[-1])
    return tf.keras.Model(inputs=tf.nest.flatten(inputs), outputs=logits)
Ejemplo n.º 3
0
def very_dense_classifier(input_spec,
                          output_spec,
                          dense_factory=Dense,
                          features_fn=very_dense_features):
    """Build a classifier with one logits head per non-None global feature.

    Args:
        input_spec: nested input spec passed to `spec.inputs`.
        output_spec: output spec; `output_spec.shape[-1]` is the number of
            classes.
        dense_factory: factory `(units, activation=...) -> layer` for heads.
        features_fn: callable returning `(node_features, edge_features,
            global_features)`; only `global_features` is used.

    Returns:
        `tf.keras.Model` with a list of logits outputs, one per global feature
        that is not None.
    """
    num_classes = output_spec.shape[-1]
    inputs = spec.inputs(input_spec)
    _, _, global_features = features_fn(inputs)
    heads = [
        dense_factory(num_classes, activation=None)(features)
        for features in global_features
        if features is not None
    ]
    return tf.keras.Model(inputs=tf.nest.flatten(inputs), outputs=heads)
Ejemplo n.º 4
0
def very_dense_semantic_segmenter(input_spec,
                                  output_spec,
                                  dense_factory=Dense,
                                  features_fn=very_dense_features):
    """Build a segmenter with one logits head per node-feature entry.

    Args:
        input_spec: nested input spec passed to `spec.inputs`. The created
            inputs are used as a dict. They may contain 'class_masks'; when
            present they must also contain 'outer_row_splits'.
        output_spec: output spec; `output_spec.shape[-1]` is the number of
            classes.
        dense_factory: factory `(units) -> layer` for the logits heads.
        features_fn: callable returning `(node_features, edge_features,
            global_features)`. It receives the inputs without 'class_masks'.
            Only the first element of each `node_features` entry is used.

    Returns:
        `tf.keras.Model` with a list of logits outputs. When 'class_masks' is
        given, logits outside the (row-repeated) mask are set to -inf.
    """
    num_classes = output_spec.shape[-1]
    inputs = spec.inputs(input_spec)
    # Hide `class_masks` from `features_fn` without removing it from the model
    # inputs: it feeds the masking below, so dropping it from the tensors
    # passed to `tf.keras.Model` would leave the functional graph disconnected.
    feature_inputs = dict(inputs)
    class_masks = feature_inputs.pop('class_masks', None)
    node_features, edge_features, global_features = features_fn(feature_inputs)
    del edge_features, global_features
    node_features = [nf[0] for nf in node_features]  # high res features
    preds = [dense_factory(num_classes)(n) for n in node_features]
    if class_masks is not None:
        if_false = tf.fill(tf.shape(preds[0]),
                           value=tf.constant(-np.inf, dtype=tf.float32))
        outer_row_splits = inputs['outer_row_splits'][0]
        outer_row_lengths = op_utils.diff(outer_row_splits)
        # Repeat each row of the mask once per element of that row.
        class_masks = tf.repeat(class_masks, outer_row_lengths, axis=0)
        preds = [tf.where(class_masks, pred, if_false) for pred in preds]
    return tf.keras.Model(inputs=tf.nest.flatten(inputs), outputs=preds)
Ejemplo n.º 5
0
def get_model(input_spec,
              output_spec,
              conv_filters=(16, 32),
              dense_units=(),
              activation='relu'):
    """Build a small conv/dense classifier with scheduled batch-norm momentum.

    Each conv (kernel 3) and dense layer is followed by a
    `VariableMomentumBatchNormalization` and an activation.

    Side effect: appends a `ScheduleUpdater` to `cb.aggregator`. The updater
    drives the `momentum` of every `VariableMomentumBatchNormalization` layer
    with an exponential-decay-towards schedule, logged under
    'batch_norm_momentum'.

    Args:
        input_spec: input spec passed to `spec.inputs`.
        output_spec: output spec; `output_spec.shape[-1]` is the number of
            classes.
        conv_filters: filters of each `Conv2D` layer.
        dense_units: units of each hidden `Dense` layer.
        activation: activation identifier for `tf.keras.layers.Activation`.

    Returns:
        `tf.keras.Model` producing logits.
    """
    def normalize_and_activate(tensor):
        tensor = VariableMomentumBatchNormalization()(tensor)
        return tf.keras.layers.Activation(activation)(tensor)

    inputs = spec.inputs(input_spec)
    hidden = inputs
    for num_filters in conv_filters:
        hidden = normalize_and_activate(
            tf.keras.layers.Conv2D(num_filters, 3)(hidden))
    hidden = tf.keras.layers.Flatten()(hidden)
    for units in dense_units:
        hidden = normalize_and_activate(tf.keras.layers.Dense(units)(hidden))

    logits = tf.keras.layers.Dense(output_spec.shape[-1])(hidden)

    momentum_schedule = functools.partial(sched.exponential_decay_towards,
                                          initial_value=0.5,
                                          decay_steps=1,
                                          decay_rate=0.5,
                                          asymptote=1.0,
                                          clip_value=0.99,
                                          impl=np)

    def momentum_variables(model):
        return [
            layer.momentum
            for layer in model.layers
            if isinstance(layer, VariableMomentumBatchNormalization)
        ]

    cb.aggregator.append(
        cb.schedule_updater.ScheduleUpdater(schedule=momentum_schedule,
                                            variables_func=momentum_variables,
                                            logs_key='batch_norm_momentum'))
    return tf.keras.Model(inputs=inputs, outputs=logits)
Ejemplo n.º 6
0
def generalized_semantic_segmenter(
        input_spec,
        output_spec,
        dense_factory=mk_layers.Dense,
        batch_norm_impl=mk_layers.BatchNormalization,
        activation='relu',
        filters0=32):
    """Build an encoder/decoder ragged-convolution semantic segmenter.

    Encoder:
        - A featureless ragged convolution runs first.
        - Then, for each level `i` in `range(depth - 1)`, an in-place
          bottleneck uses neighborhood index `2 * i`.
        - It is followed by a down-sampling bottleneck that uses index
          `2 * i + 1` and `sample_indices[i]`, doubling the filter count.

    Decoder:
        - This path mirrors the encoder with in-place bottlenecks.
        - Each level except the deepest adds the matching encoder output as a
          residual.
        - Transposed ragged convolutions use index `2 * i - 1`, halving the
          filter count.
        - The decoder finishes with an in-place bottleneck at index 0 and a
          dense classification layer.

    Args:
        input_spec: input spec passed to `spec.inputs`. The created inputs
            must be a dict with keys 'all_coords', 'flat_rel_coords',
            'feature_weights', 'flat_node_indices', 'row_splits' and
            'sample_indices'. 'normals' is not supported.
        output_spec: output spec; `output_spec.shape[-1]` is the number of
            classes.
        dense_factory: dense layer factory used throughout.
        batch_norm_impl: passed to `get_batch_norm`.
        activation: passed to `get_activation`.
        filters0: filter count at the finest level.

    Returns:
        `tf.keras.models.Model` producing per-point logits.

    Raises:
        NotImplementedError: if the inputs contain 'normals'.
    """
    batch_norm_fn = get_batch_norm(batch_norm_impl)
    activation_fn = get_activation(activation)

    inputs = spec.inputs(input_spec)
    num_classes = output_spec.shape[-1]

    (all_coords, flat_rel_coords, feature_weights, flat_node_indices,
     row_splits, sample_indices) = (inputs[k] for k in (
         'all_coords',
         'flat_rel_coords',
         'feature_weights',
         'flat_node_indices',
         'row_splits',
         'sample_indices',
     ))

    depth = len(all_coords)
    features = inputs.get('normals')

    filters = filters0
    if features is None:
        features = gen_layers.FeaturelessRaggedConvolution(filters)(
            [flat_rel_coords[0], flat_node_indices[0], feature_weights[0]])
    else:
        raise NotImplementedError()

    activation_kwargs = dict(batch_norm_fn=batch_norm_fn,
                             activation_fn=activation_fn)
    bottleneck_kwargs = dict(dense_factory=dense_factory, **activation_kwargs)

    features = activation_fn(batch_norm_fn(features))

    # Encoder: keep each level's in-place output for the decoder residuals.
    res_features = []
    for i in range(depth - 1):
        features = blocks.in_place_bottleneck(features, flat_rel_coords[2 * i],
                                              flat_node_indices[2 * i],
                                              row_splits[2 * i],
                                              feature_weights[2 * i],
                                              **bottleneck_kwargs)
        res_features.append(features)
        filters *= 2
        features = blocks.down_sample_bottleneck(features,
                                                 flat_rel_coords[2 * i + 1],
                                                 flat_node_indices[2 * i + 1],
                                                 row_splits[2 * i + 1],
                                                 feature_weights[2 * i + 1],
                                                 sample_indices[i],
                                                 filters=filters,
                                                 **bottleneck_kwargs)

    # Decoder: deepest level first. The deepest level has no stored residual.
    for i in range(depth - 1, 0, -1):
        features = blocks.in_place_bottleneck(features,
                                              flat_rel_coords[2 * i],
                                              flat_node_indices[2 * i],
                                              row_splits[2 * i],
                                              feature_weights[2 * i],
                                              filters=filters,
                                              **bottleneck_kwargs)
        if i != depth - 1:
            features = features + res_features.pop()
        filters //= 2
        features = gen_layers.RaggedConvolutionTranspose(
            filters, dense_factory=dense_factory)([
                features, flat_rel_coords[2 * i - 1],
                flat_node_indices[2 * i - 1], row_splits[2 * i - 1],
                feature_weights[2 * i - 1]
            ])
        features = activation_fn(batch_norm_fn(features))

    # Final in-place block at the finest level.
    features = blocks.in_place_bottleneck(features,
                                          flat_rel_coords[0],
                                          flat_node_indices[0],
                                          row_splits[0],
                                          feature_weights[0],
                                          filters=filters,
                                          **bottleneck_kwargs)
    # With depth == 1 the encoder loop never runs, so there is no residual to
    # add; popping unconditionally would raise IndexError.
    if res_features:
        features = features + res_features.pop()

    # per-point classification layer
    logits = dense_factory(num_classes, activation=None,
                           use_bias=True)(features)

    return tf.keras.models.Model(tf.nest.flatten(inputs), logits)
Ejemplo n.º 7
0
def generalized_classifier(input_spec,
                           output_spec,
                           coord_features_fn=get_coord_features,
                           dense_factory=mk_layers.Dense,
                           batch_norm_impl=mk_layers.BatchNormalization,
                           activation='relu',
                           global_filters=(512, 256),
                           filters0=32,
                           global_dropout_impl=None):
    """Build a ragged-convolution point cloud classifier.

    Encoder:
        - A featureless ragged convolution runs first.
        - Then, for each level `i` in `range(depth - 1)`, an in-place
          bottleneck uses neighborhood index `2 * i`.
        - It is followed by a down-sampling bottleneck that uses index
          `2 * i + 1`, doubling the filter count.

    Head:
        - A final in-place bottleneck is applied.
        - Then a global ragged convolution over `all_coords[-1]`.
        - Then an `mlp` producing the logits.

    Args:
        input_spec: input spec passed to `spec.inputs`. The created inputs
            must be a dict with keys 'all_coords', 'flat_rel_coords',
            'flat_node_indices', 'row_splits', 'sample_indices' and
            'feature_weights'. 'normals' is not supported.
        output_spec: output spec; `output_spec.shape[-1]` is the number of
            classes.
        coord_features_fn: applied to each `flat_rel_coords` entry to produce
            the coordinate features given to the encoder bottlenecks.
        dense_factory: dense layer factory for the bottlenecks and the global
            convolution.
        batch_norm_impl: passed to `get_batch_norm` and the final `mlp`.
        activation: passed to `get_activation` and the final `mlp`.
        global_filters: filters of the global convolution (first entry) and
            hidden units of the final `mlp` (remaining entries).
        filters0: filter count at the finest level.
        global_dropout_impl: dropout implementation for the final `mlp`.

    Returns:
        `tf.keras.Model` producing logits.

    Raises:
        NotImplementedError: if the inputs contain 'normals'.
    """
    batch_norm_fn = get_batch_norm(batch_norm_impl)
    activation_fn = get_activation(activation)

    inputs = spec.inputs(input_spec)
    num_classes = output_spec.shape[-1]

    (all_coords, flat_rel_coords, flat_node_indices, row_splits,
     sample_indices, feature_weights) = (inputs[k] for k in (
         'all_coords',
         'flat_rel_coords',
         'flat_node_indices',
         'row_splits',
         'sample_indices',
         'feature_weights',
     ))

    depth = len(all_coords)
    coord_features = tuple(coord_features_fn(rc) for rc in flat_rel_coords)
    features = inputs.get('normals')

    filters = filters0
    if features is None:
        features = gen_layers.FeaturelessRaggedConvolution(filters)(
            [flat_rel_coords[0], flat_node_indices[0], feature_weights[0]])
    else:
        raise NotImplementedError()

    activation_kwargs = dict(batch_norm_fn=batch_norm_fn,
                             activation_fn=activation_fn)
    bottleneck_kwargs = dict(dense_factory=dense_factory, **activation_kwargs)

    features = activation_fn(batch_norm_fn(features))
    for i in range(depth - 1):
        # in place
        features = blocks.in_place_bottleneck(features,
                                              coord_features[2 * i],
                                              flat_node_indices[2 * i],
                                              row_splits[2 * i],
                                              weights=feature_weights[2 * i],
                                              **bottleneck_kwargs)
        # down sample
        filters *= 2
        features = blocks.down_sample_bottleneck(features,
                                                 coord_features[2 * i + 1],
                                                 flat_node_indices[2 * i + 1],
                                                 row_splits[2 * i + 1],
                                                 feature_weights[2 * i + 1],
                                                 sample_indices[i],
                                                 filters=filters,
                                                 **bottleneck_kwargs)

    # NOTE(review): the encoder bottlenecks above receive `coord_features`,
    # but this one receives the raw `flat_rel_coords[-1]`. Confirm whether
    # `coord_features[-1]` was intended.
    features = blocks.in_place_bottleneck(features,
                                          flat_rel_coords[-1],
                                          flat_node_indices[-1],
                                          row_splits[-1],
                                          feature_weights[-1],
                                          filters=filters,
                                          **bottleneck_kwargs)

    # global conv; honor `dense_factory` like every other layer above
    # (previously hard-coded to `mk_layers.Dense`, which is still the default).
    global_coords = all_coords[-1]
    features = gen_layers.GlobalRaggedConvolution(
        global_filters[0], dense_factory=dense_factory)(
            [features, global_coords.flat_values, global_coords.row_splits])
    logits = mlp(global_filters[1:],
                 activate_first=True,
                 final_units=num_classes,
                 batch_norm_impl=batch_norm_impl,
                 activation=activation,
                 dropout_impl=global_dropout_impl)(features)
    return tf.keras.Model(tf.nest.flatten(inputs), logits)
Ejemplo n.º 8
0
def pointnet_classifier(
    input_spec,
    output_spec,
    training=None,
    use_batch_norm=True,
    batch_norm_momentum=0.99,
    dropout_rate=0.3,
    reduction=tf.reduce_max,
    units0=(64, 64),
    units1=(64, 128, 1024),
    global_units=(512, 256),
    transform_reg_weight=0.001 / 2 * 32,  # account for averaging
    transpose_transform=False,
):
    """
    Get a pointnet classifier.

    Args:
        input_spec: input spec passed to `spec.inputs`, representing cloud
            coordinates.
        output_spec: InputSpec (shape, dtype attrs) of the output; its last
            dimension is the number of classes.
        training: training-mode flag forwarded to both transform nets and
            every mlp.
        use_batch_norm: flag indicating usage of batch norm.
        batch_norm_momentum: momentum value of batch norm.
            - If this is a callable, it is used as the schedule of a
              `cb.ScheduleUpdater` appended to `cb.aggregator`. That updater
              drives the `momentum` of every
              `VariableMomentumBatchNormalization` layer, and the layers are
              built with an initial momentum of 0.99.
            - Any other value (including a dict) is forwarded unchanged to the
              transform nets and mlps. NOTE(review): dicts are not
              deserialized here.
            - Ignored if use_batch_norm is False.
        dropout_rate: rate used in Dropout for global mlp.
        reduction: reduction function accepting (., axis) arguments.
        units0: units in initial local mlp network.
        units1: units in second local mlp network.
        global_units: units in global mlp network.
        transform_reg_weight: weight used in the l2 regularizer of the
            feature transform.
            - We use the sum of squared differences over the matrix
              dimensions, averaged over the batch dimension.
            - The original paper uses tf.nn.l2_loss (which includes a factor
              of a half) and no batch-dimension averaging, hence the odd
              default value.
        transpose_transform:
            False: what the pointnet paper describes, x' = x @ A.T
                (equivalent to x'.T = A @ x.T)
            True: what the pointnet code implements, x' = x @ A
                (equivalent to x'.T = A.T @ x.T)
            This is significant in the case where there is regularization
            weight, since |I - A.T @ A| != |I - A @ A.T|.

    Returns:
        keras model with logits as outputs. Any momentum callback is
        registered through `cb.aggregator` rather than returned.
    """
    transform_kwargs = dict(transpose_b=not transpose_transform)
    inputs = spec.inputs(input_spec)
    if use_batch_norm and callable(batch_norm_momentum):
        # Keep the schedule before replacing the argument with a constant
        # initial momentum; previously the schedule was overwritten first and
        # the updater received 0.99 instead of the callable.
        momentum_schedule = batch_norm_momentum
        batch_norm_momentum = 0.99  # initial value; the updater takes over
        cb.aggregator.append(
            cb.ScheduleUpdater(
                schedule=momentum_schedule,
                variables_func=lambda model: [
                    layer.momentum for layer in model.layers
                    if isinstance(layer, VariableMomentumBatchNormalization)
                ]))

    bn_kwargs = dict(use_batch_norm=use_batch_norm,
                     batch_norm_momentum=batch_norm_momentum)
    num_classes = output_spec.shape[-1]
    cloud = inputs
    transform0 = feature_transform_net(cloud,
                                       3,
                                       training=training,
                                       **bn_kwargs)
    cloud = layers.Lambda(apply_transform,
                          arguments=transform_kwargs)([cloud, transform0
                                                       ])  # TF-COMPAT
    cloud = mlp(cloud, units0, training=training, **bn_kwargs)

    # Pass `training` here too, consistent with transform0 and the mlps.
    transform1 = feature_transform_net(
        cloud,
        units0[-1],
        training=training,
        transform_reg_weight=transform_reg_weight,
        **bn_kwargs)
    cloud = layers.Lambda(apply_transform,
                          arguments=transform_kwargs)([cloud, transform1
                                                       ])  # TF-COMPAT

    cloud = mlp(cloud, units1, training=training, **bn_kwargs)

    # Global feature: reduce over the point axis (-2).
    features = layers.Lambda(reduction,
                             arguments=dict(axis=-2))(cloud)  # TF-COMPAT
    features = mlp(features,
                   global_units,
                   training=training,
                   dropout_rate=dropout_rate,
                   **bn_kwargs)
    logits = tf.keras.layers.Dense(num_classes)(features)

    return tf.keras.models.Model(inputs=tf.nest.flatten(inputs),
                                 outputs=logits)