def down_sample_block(node_features, edge_features, neigh_indices,
                      sample_indices, name, filters=None,
                      conv_shortcut=False):
    """
    Residual bottleneck block for down-sampling convolution.

    Analogous to strided image convolutions: output features are produced only
    for the nodes selected by `sample_indices`.

    Args:
        node_features: [N, F] input node features.
        edge_features: [E, F_edge] edge features.
        neigh_indices: [E, 2] sparse indices of neighbors (presumably pairs of
            output-node / input-node indices -- confirm against
            `SparseCloudConvolution`).
        sample_indices: indices into `node_features` of the nodes kept after
            down-sampling (presumably 1-D, [N_out]).
        name: string name prefix used for every layer created here.
        filters: number of filters in the bottleneck layer. If None, uses
            `F // 4`. Must be provided if `conv_shortcut` is True.
        conv_shortcut: if True, uses a dense layer on the shortcut connection,
            otherwise uses identity and `filters` must be `None` or `F // 4`,
            and `F % 4 == 0`.

    Returns:
        [N_out, filters * 4] output features, in the order of `sample_indices`.

    Raises:
        ValueError: if `conv_shortcut` is True and `filters` is None.
    """
    # Validate before building any ops; `None * 4` below would otherwise fail
    # with an unhelpful TypeError. Matches `in_place_block`.
    if conv_shortcut and filters is None:
        raise ValueError('filters must be provided if conv_shortcut is True')

    # The residual branch only exists at the sampled nodes.
    shortcut = tf.gather(node_features, sample_indices)
    if conv_shortcut:
        shortcut = Dense(filters * 4, name=name + '_shortcut_dense')(shortcut)
        shortcut = BatchNormalization(name=name + '_shortcut_bn')(shortcut)
    else:
        # Identity shortcut: `filters * 4` must equal the input width.
        filters = _get_filters(filters, node_features.shape[-1])

    # Number of output nodes for the sparse convolution.
    out_size = tf.shape(sample_indices, out_type=tf.int64)[0]

    # Bottleneck: reduce to `filters`, convolve, expand to `filters * 4`.
    x = Dense(filters, name=name + '_bottleneck_0_dense')(node_features)
    x = BatchNormalization(name=name + '_bottleneck_0_bn')(x)
    x = Activation('relu', name=name + '_bottleneck_0_relu')(x)
    x = sparse_layers.SparseCloudConvolution(filters, name=name + '_conv')(
        [x, edge_features, neigh_indices, out_size])
    x = BatchNormalization(name=name + '_conv_bn')(x)
    x = Activation('relu', name=name + '_conv_relu')(x)
    x = Dense(filters * 4, name=name + '_bottleneck_1_dense')(x)
    x = BatchNormalization(name=name + '_bottleneck_1_bn')(x)

    x = Add(name=name + '_add')([shortcut, x])
    x = Activation('relu', name=name + '_out')(x)
    return x
def in_place_block(node_features, edge_features, neigh_indices, name,
                   filters=None, conv_shortcut=False):
    """
    Residual bottleneck block for in-place convolution.

    Analogous to image convolutions with `stride == 1, padding='SAME'`.

    Args:
        node_features: [N, F] initial node features.
        edge_features: [E, F_edge] edge features.
        neigh_indices: [E, 2] sparse indices of neighbors.
        name: string name prefix used for every layer created here.
        filters: number of filters in bottleneck layer. If None, uses `F // 4`.
        conv_shortcut: if True, uses a dense layer on shortcut connection,
            otherwise uses identity and `filters` must be `None` or `F // 4`,
            and `F % 4 == 0`.

    Returns:
        [N, filters * 4] output features in the same node ordering.

    Raises:
        ValueError: if `conv_shortcut` is True and `filters` is None.
    """
    # One output row per input node.
    out_size = tf.shape(node_features, out_type=tf.int64)[0]
    if conv_shortcut:
        if filters is None:
            raise ValueError(
                'filters must be provided if conv_shortcut is True')
        shortcut = Dense(4 * filters,
                         name=name + '_shortcut_dense')(node_features)
        shortcut = BatchNormalization(name=name + '_shortcut_bn')(shortcut)
    else:
        shortcut = node_features
        # Identity shortcut: `filters * 4` must equal the input width.
        filters = _get_filters(filters, node_features.shape[-1])

    # Bottleneck: reduce to `filters`, convolve, expand to `filters * 4`.
    # The first dense previously used `filters * 4` units, which made it a
    # full-width projection rather than a bottleneck (inconsistent with
    # `down_sample_block`). NOTE: this changes the kernel shapes of
    # `_bottleneck_0_dense` and `_conv`, so old checkpoints will not load.
    x = Dense(filters, name=name + '_bottleneck_0_dense')(node_features)
    x = BatchNormalization(name=name + '_bottleneck_0_bn')(x)
    x = Activation('relu', name=name + '_bottleneck_0_relu')(x)
    x = sparse_layers.SparseCloudConvolution(filters, name=name + '_conv')(
        [x, edge_features, neigh_indices, out_size])
    x = BatchNormalization(name=name + '_conv_bn')(x)
    x = Activation('relu', name=name + '_conv_relu')(x)
    x = Dense(filters * 4, name=name + '_bottleneck_1_dense')(x)
    x = BatchNormalization(name=name + '_bottleneck_1_bn')(x)

    x = Add(name=name + '_add')([shortcut, x])
    x = Activation('relu', name=name + '_out')(x)
    return x
def mlp_layer(flat_features, units, activation='relu',
              use_batch_normalization=False, dropout_rate=None):
    """
    Build and apply a single dense layer with optional dropout and batch norm.

    Pipeline, in order: Dense (with `activation`) -> Dropout (when
    `dropout_rate` is not None) -> BatchNormalization (when
    `use_batch_normalization`).

    Args:
        flat_features: [N, units_in] float32 non-ragged features.
        units: number of output features.
        activation: activation passed to the Dense layer.
        use_batch_normalization: if True, apply batch normalization after
            dropout and build the Dense layer without a bias.
        dropout_rate: dropout rate; dropout is skipped if this is None.

    Returns:
        [N, units] float32 output features.
    """
    # NOTE(review): the bias is dropped whenever BN is used, but BN runs after
    # the nonlinearity, so its offset cannot stand in for a pre-activation
    # bias. Confirm this is intended.
    dense = Dense(units, activation=activation,
                  use_bias=not use_batch_normalization)
    out = dense(flat_features)

    if dropout_rate is not None:
        out = Dropout(rate=dropout_rate)(out)

    if use_batch_normalization:
        out = BatchNormalization()(out)

    return out
def classifier_logits(inputs, num_classes, features_fn=resnet_features,
                      final_node_units=1024, mlp_fn=mlp):
    """
    Build classification logits from a point-cloud input dict.

    Args:
        inputs: mapping read with `.get` for the keys 'sample_indices',
            'in_place_indices', 'in_place_rel_coords', 'down_sample_indices',
            'down_sample_rel_coords' and 'row_splits' (missing keys become
            None).
        num_classes: number of output logits.
        features_fn: feature extractor returning a 5-tuple whose first element
            is a sequence of per-level node features; only its last entry is
            used.
        final_node_units: width of the final per-node dense layer.
        mlp_fn: head applied to the pooled features before the logits layer.

    Returns:
        Output of `Dense(num_classes, name='logits')` applied to the head.
    """
    input_keys = ('sample_indices', 'in_place_indices', 'in_place_rel_coords',
                  'down_sample_indices', 'down_sample_rel_coords',
                  'row_splits')
    (sample_indices, in_place_indices, in_place_rel_coords,
     down_sample_indices, down_sample_rel_coords,
     row_splits) = [inputs.get(key) for key in input_keys]

    # Only the node features are needed; edge features and weights are
    # discarded (the 5-way unpack still enforces the tuple length).
    out_features, _, _, _, _ = features_fn(
        sample_indices, in_place_rel_coords, in_place_indices,
        down_sample_rel_coords, down_sample_indices)

    node_features = out_features[-1]
    node_features = Dense(final_node_units,
                          name='node_final_dense')(node_features)
    node_features = BatchNormalization(name='node_final_bn')(node_features)

    # Regroup nodes by cloud using the last level's row splits (presumably
    # yielding a [B, max_nodes, units] tensor -- see `_from_row_splits`), then
    # max-pool over the node axis.
    grouped = tf.keras.layers.Lambda(_from_row_splits)(
        [node_features, row_splits[-1]])
    pooled = tf.keras.layers.Lambda(tf.reduce_max,
                                    arguments={'axis': 1})(grouped)
    # ReLU is monotonic, so applying it after the max equals max of ReLUs.
    pooled = Activation('relu', name='node_final_relu')(pooled)

    head = mlp_fn(pooled)
    return Dense(num_classes, name='logits')(head)
def mlp(x, units=(512, 256), dropout_rate=0.4):
    """
    Apply a stack of Dense -> BatchNorm -> ReLU [-> Dropout] layers.

    Args:
        x: input features.
        units: width of each hidden layer, in order.
        dropout_rate: dropout rate applied after every layer; no dropout if
            None.

    Returns:
        Features after the last layer (`x` unchanged if `units` is empty).
    """
    for layer_idx, width in enumerate(units):
        dense = Dense(width, name='mlp_{}_dense'.format(layer_idx))
        norm = BatchNormalization(name='mlp_{}_bn'.format(layer_idx))
        act = Activation('relu', name='mlp_{}_relu'.format(layer_idx))
        x = act(norm(dense(x)))
        if dropout_rate is not None:
            x = tf.keras.layers.Dropout(dropout_rate)(x)
    return x
def up_sample_combine(upper_node_features, node_features, edge_features,
                      neigh_indices, name):
    """
    Up-sample `node_features` onto the upper level and add them residually.

    Args:
        upper_node_features: [N_upper, F_upper] features at the upper
            (denser) level; also used as the residual branch.
        node_features: features at the current (coarser) level.
        edge_features: edge features for the up-sampling convolution.
        neigh_indices: sparse neighbor indices for the up-sampling
            convolution.
        name: string name prefix used for every layer created here.

    Returns:
        [N_upper, F_upper] combined features, after a final ReLU.
    """
    # One output row per upper-level node, with a width matching the residual.
    out_size = tf.shape(upper_node_features, out_type=tf.int64)[0]
    # Named like the other convolutions in this file. Previously this layer
    # was unnamed, so it got an order-dependent auto-generated name.
    x = sparse_layers.SparseCloudConvolution(
        upper_node_features.shape[-1], name=name + '_conv')(
            [node_features, edge_features, neigh_indices, out_size])
    x = BatchNormalization(name=name + '_bn')(x)
    x = Add(name=name + '_add')([upper_node_features, x])
    x = Activation('relu', name=name + '_out')(x)
    return x
def generalized_activation(features, activation='relu', add_bias=False,
                           use_batch_normalization=False, dropout_rate=None):
    """
    Apply a configurable post-layer activation stack.

    Intended for use (presumably) right after a learned linear layer. Steps,
    in order, each optional except the activation:
      * bias addition (only if `add_bias` and batch norm is disabled);
      * the activation named/given by `activation`;
      * batch normalization (if `use_batch_normalization`);
      * dropout (if `dropout_rate` is truthy, so 0 and None both skip it).

    Returns:
        The transformed features.
    """
    skip_bias = use_batch_normalization or not add_bias
    if not skip_bias:
        features = bias.add_bias(features)

    activation_fn = tf.keras.activations.get(activation)
    out = tf.keras.layers.Lambda(activation_fn)(features)

    if use_batch_normalization:
        out = BatchNormalization()(out)
    if dropout_rate:
        out = Dropout(dropout_rate)(out)
    return out
def initial_block(edge_features, neigh_indices, filters=64, name='initial'):
    """
    First network block: featureless sparse convolution -> BN -> ReLU.

    Args:
        edge_features: edge features fed to the featureless convolution.
        neigh_indices: sparse neighbor indices for the convolution.
        filters: number of output filters.
        name: string name prefix used for every layer created here.

    Returns:
        Output features with `filters` channels.
    """
    conv = sparse_layers.FeaturelessSparseCloudConvolution(
        filters, name=name + '_conv')
    norm = BatchNormalization(name=name + '_bn')
    relu = Activation('relu', name=name + '_relu')
    return relu(norm(conv([edge_features, neigh_indices])))