Example #1
0
def down_sample_block(node_features,
                      edge_features,
                      neigh_indices,
                      sample_indices,
                      name,
                      filters=None,
                      conv_shortcut=False):
    """
    Residual bottleneck block that convolves onto a subset of the nodes.

    Outputs are produced only for the nodes selected by `sample_indices`
    (loosely analogous to a strided image convolution).

    Args:
        node_features: [N, F] input node features.
        edge_features: [E, F_edge] edge features.
        neigh_indices: [E, 2] sparse indices of neighbors (presumably
            (output_node, input_node) pairs -- confirm against
            `SparseCloudConvolution`).
        sample_indices: [N_out] indices into `node_features` of the nodes kept
            in the output.
        name: string name prefix for every created layer.
        filters: number of filters in the bottleneck layers. Required if
            `conv_shortcut` is True; otherwise resolved by `_get_filters` from
            `F` (presumably `F // 4` when None -- verify).
        conv_shortcut: if True, the gathered shortcut features go through
            Dense(4 * filters) + BatchNormalization; otherwise they are used
            as-is, so `F` must equal `4 * filters` for the residual add.

    Returns:
        [N_out, filters * 4] output features, in `sample_indices` order.

    Raises:
        ValueError: if `conv_shortcut` is True and `filters` is None.
    """
    # Validate before building any layers so a misconfiguration fails with a
    # clear message instead of a TypeError on `None * 4`.
    if conv_shortcut and filters is None:
        raise ValueError('filters must be provided if conv_shortcut is True')

    # Shortcut branch: keep only the sampled nodes.
    shortcut = tf.gather(node_features, sample_indices)
    if conv_shortcut:
        shortcut = Dense(filters * 4, name=name + '_shortcut_dense')(shortcut)
        shortcut = BatchNormalization(name=name + '_shortcut_bn')(shortcut)
    else:
        filters = _get_filters(filters, node_features.shape[-1])

    # Number of output nodes for the convolution.
    out_size = tf.shape(sample_indices, out_type=tf.int64)[0]

    # Main branch: reduce -> sparse conv onto sampled nodes -> expand.
    x = Dense(filters, name=name + '_bottleneck_0_dense')(node_features)
    x = BatchNormalization(name=name + '_bottleneck_0_bn')(x)
    x = Activation('relu', name=name + '_bottleneck_0_relu')(x)
    x = sparse_layers.SparseCloudConvolution(filters, name=name + '_conv')(
        [x, edge_features, neigh_indices, out_size])
    x = BatchNormalization(name=name + '_conv_bn')(x)
    x = Activation('relu', name=name + '_conv_relu')(x)
    x = Dense(filters * 4, name=name + '_bottleneck_1_dense')(x)
    x = BatchNormalization(name=name + '_bottleneck_1_bn')(x)

    x = Add(name=name + '_add')([shortcut, x])
    x = Activation('relu', name=name + '_out')(x)
    return x
Example #2
0
def in_place_block(node_features,
                   edge_features,
                   neigh_indices,
                   name,
                   filters=None,
                   conv_shortcut=False):
    """
    Residual bottleneck block for in-place convolution.

    Analogous to image convolutions with `stride == 1, padding='SAME'`.

    Main branch: Dense(filters) -> BN -> ReLU -> SparseCloudConvolution(filters)
    -> BN -> ReLU -> Dense(4 * filters) -> BN. It is added to the shortcut
    branch and passed through a final ReLU.

    Args:
        node_features: [N, F] initial node features.
        edge_features: [E, F_edge] edge features.
        neigh_indices: [E, 2] sparse indices of neighbors.
        name: string name prefix for every created layer.
        filters: number of filters in bottleneck layer. If None, uses `F // 4`.
        conv_shortcut: if True, uses a dense layer on shortcut connection,
            otherwise uses identity and `filters` must be `None` or `F // 4`,
            and `F % 4 == 0`.

    Returns:
        [N, filters * 4] output features in the same node ordering.

    Raises:
        ValueError: if `conv_shortcut` is True and `filters` is None.
    """
    # Validate before building any layers/ops.
    if conv_shortcut and filters is None:
        raise ValueError(
            'filters must be provided if conv_shortcut is True')

    out_size = tf.shape(node_features, out_type=tf.int64)[0]
    if conv_shortcut:
        shortcut = Dense(4 * filters,
                         name=name + '_shortcut_dense')(node_features)
        shortcut = BatchNormalization(name=name + '_shortcut_bn')(shortcut)
    else:
        shortcut = node_features
        filters = _get_filters(filters, node_features.shape[-1])

    # Bottleneck: reduce to `filters` channels before the sparse convolution.
    # (Previously this projected to `filters * 4`, which is not a bottleneck
    # and disagreed with both the docstring and `down_sample_block`.)
    x = Dense(filters, name=name + '_bottleneck_0_dense')(node_features)
    x = BatchNormalization(name=name + '_bottleneck_0_bn')(x)
    x = Activation('relu', name=name + '_bottleneck_0_relu')(x)
    x = sparse_layers.SparseCloudConvolution(filters, name=name + '_conv')(
        [x, edge_features, neigh_indices, out_size])
    x = BatchNormalization(name=name + '_conv_bn')(x)
    x = Activation('relu', name=name + '_conv_relu')(x)
    # Expand back to 4 * filters so the residual add matches the shortcut.
    x = Dense(filters * 4, name=name + '_bottleneck_1_dense')(x)
    x = BatchNormalization(name=name + '_bottleneck_1_bn')(x)
    x = Add(name=name + '_add')([shortcut, x])
    x = Activation('relu', name=name + '_out')(x)
    return x
Example #3
0
def mlp_layer(flat_features,
              units,
              activation='relu',
              use_batch_normalization=False,
              dropout_rate=None):
    """
    Apply one dense layer, optionally followed by dropout and batch norm.

    Order of operations: Dense (with `activation` applied inside it), then
    Dropout, then BatchNormalization.

    Args:
        flat_features: [N, units_in] float32 non-ragged features.
        units: number of output features.
        activation: activation passed to the Dense layer.
        use_batch_normalization: if True, batch normalization is applied after
            dropout, and the Dense layer is built without a bias.
        dropout_rate: dropout rate; dropout is skipped only when this is None.

    Returns:
        [N, units] float32 output features.
    """
    # NOTE(review): the bias is dropped whenever batch norm is on, but batch
    # norm runs *after* the activation here, so it does not stand in for the
    # pre-activation bias -- confirm this is intended.
    dense = Dense(units,
                  activation=activation,
                  use_bias=not use_batch_normalization)
    out = dense(flat_features)
    if dropout_rate is not None:
        out = Dropout(rate=dropout_rate)(out)
    return BatchNormalization()(out) if use_batch_normalization else out
Example #4
0
def classifier_logits(inputs,
                      num_classes,
                      features_fn=resnet_features,
                      final_node_units=1024,
                      mlp_fn=mlp):
    """
    Build classification logits from a dict of point-cloud inputs.

    Args:
        inputs: mapping read with `.get` for the keys 'sample_indices',
            'in_place_indices', 'in_place_rel_coords', 'down_sample_indices',
            'down_sample_rel_coords' and 'row_splits' (missing keys become
            None).
        num_classes: number of logits produced per example.
        features_fn: callable returning a 5-tuple whose first element is a
            sequence of node features; only its last entry is used.
        final_node_units: width of the per-node Dense layer before pooling.
        mlp_fn: callable applied to the pooled features before the logits
            layer.

    Returns:
        output of a Dense(num_classes) layer named 'logits'.
    """
    keys = ('sample_indices', 'in_place_indices', 'in_place_rel_coords',
            'down_sample_indices', 'down_sample_rel_coords', 'row_splits')
    (sample_indices, in_place_indices, in_place_rel_coords,
     down_sample_indices, down_sample_rel_coords,
     row_splits) = [inputs.get(key) for key in keys]

    # Only the node features are used; the other four outputs are discarded.
    out_features, _, _, _, _ = features_fn(
        sample_indices, in_place_rel_coords, in_place_indices,
        down_sample_rel_coords, down_sample_indices)

    # Per-node projection of the last entry of `out_features`.
    node_features = Dense(final_node_units,
                          name='node_final_dense')(out_features[-1])
    node_features = BatchNormalization(name='node_final_bn')(node_features)

    # Regroup nodes using the last row_splits (presumably one row per
    # example -- see `_from_row_splits`), then take the max over axis 1.
    grouped = tf.keras.layers.Lambda(_from_row_splits)(
        [node_features, row_splits[-1]])
    pooled = tf.keras.layers.Lambda(tf.reduce_max,
                                    arguments=dict(axis=1))(grouped)
    pooled = Activation('relu', name='node_final_relu')(pooled)

    hidden = mlp_fn(pooled)
    return Dense(num_classes, name='logits')(hidden)
Example #5
0
def mlp(x, units=(512, 256), dropout_rate=0.4, name='mlp'):
    """
    Stack of Dense -> BatchNormalization -> ReLU (-> Dropout) layers.

    Args:
        x: input features.
        units: output width of each successive Dense layer.
        dropout_rate: dropout rate applied after each ReLU; None disables it.
        name: prefix for created layer names. The default 'mlp' reproduces
            the original names ('mlp_0_dense', ...); pass a distinct prefix to
            build more than one MLP in the same model without name collisions.

    Returns:
        output of the last layer, or `x` unchanged if `units` is empty.
    """
    for i, u in enumerate(units):
        prefix = '{}_{}'.format(name, i)
        x = Dense(u, name=prefix + '_dense')(x)
        x = BatchNormalization(name=prefix + '_bn')(x)
        x = Activation('relu', name=prefix + '_relu')(x)
        if dropout_rate is not None:
            # Named like its siblings (previously auto-named by Keras).
            x = tf.keras.layers.Dropout(dropout_rate,
                                        name=prefix + '_dropout')(x)
    return x
Example #6
0
def up_sample_combine(upper_node_features, node_features, edge_features,
                      neigh_indices, name):
    """
    Convolve `node_features` onto the upper node set and add it residually.

    Args:
        upper_node_features: [N_upper, F] features of the target nodes. They
            also serve as the residual (shortcut) branch.
        node_features: source node features fed to the sparse convolution
            (presumably the coarser level -- confirm with caller).
        edge_features: edge features for the convolution.
        neigh_indices: sparse neighbor indices for the convolution.
        name: string name prefix for every created layer.

    Returns:
        [N_upper, F] features: ReLU(upper_node_features + BN(conv(...))).
    """
    out_size = tf.shape(upper_node_features, out_type=tf.int64)[0]
    # Conv output width must match the shortcut so the two can be added.
    filters = upper_node_features.shape[-1]
    # Named like every other convolution in this module (was auto-named).
    x = sparse_layers.SparseCloudConvolution(filters, name=name + '_conv')(
        [node_features, edge_features, neigh_indices, out_size])
    x = BatchNormalization(name=name + '_bn')(x)
    x = Add(name=name + '_add')([upper_node_features, x])
    x = Activation('relu', name=name + '_out')(x)
    return x
Example #7
0
def generalized_activation(features,
                           activation='relu',
                           add_bias=False,
                           use_batch_normalization=False,
                           dropout_rate=None):
    """
    Generalized activation to be applied (presumably) after a learned layer.

    Steps, in order:
        1. bias addition, only when `add_bias` is True and batch
           normalization is not used;
        2. the activation resolved by `tf.keras.activations.get(activation)`,
           wrapped in a Lambda layer;
        3. batch normalization, if `use_batch_normalization`;
        4. dropout, if `dropout_rate` is truthy (None or 0 skip it).

    Args:
        features: input tensor.
        activation: activation identifier or callable.
        add_bias: whether to add a learned bias (ignored with batch norm).
        use_batch_normalization: whether to apply batch normalization.
        dropout_rate: dropout rate, or a falsy value to disable dropout.

    Returns:
        the transformed features.
    """
    skip_bias = use_batch_normalization or not add_bias
    if not skip_bias:
        features = bias.add_bias(features)

    activation_fn = tf.keras.activations.get(activation)
    features = tf.keras.layers.Lambda(activation_fn)(features)

    if use_batch_normalization:
        features = BatchNormalization()(features)

    if not dropout_rate:
        return features
    return Dropout(dropout_rate)(features)
Example #8
0
def initial_block(edge_features, neigh_indices, filters=64, name='initial'):
    """
    Entry block: featureless sparse convolution, then batch norm and ReLU.

    Args:
        edge_features: edge features passed to the convolution.
        neigh_indices: sparse neighbor indices passed to the convolution.
        filters: number of output filters of the convolution.
        name: string name prefix for every created layer.

    Returns:
        output of the final ReLU activation layer.
    """
    conv = sparse_layers.FeaturelessSparseCloudConvolution(
        filters, name=name + '_conv')
    norm = BatchNormalization(name=name + '_bn')
    relu = Activation('relu', name=name + '_relu')
    return relu(norm(conv([edge_features, neigh_indices])))