def down_sample_block(node_features, edge_features, neigh_indices,
                      sample_indices, name, filters=None,
                      conv_shortcut=False):
    """
    Residual bottleneck block mapping features onto a sampled subset of nodes.

    Analogous to a strided image convolution: the output has one row per
    entry of `sample_indices`.

    Args:
        node_features: [N, F] input node features.
        edge_features: edge features forwarded to the sparse convolution.
        neigh_indices: sparse neighbor indices forwarded to the sparse
            convolution. Presumably (output sample index, input node index)
            pairs — TODO confirm against `SparseCloudConvolution`.
        sample_indices: indices into the rows of `node_features` selecting
            the nodes kept at the output resolution (assumed rank 1).
        name: string prefix for the names of all created layers.
        filters: number of filters in the bottleneck layer. Required if
            `conv_shortcut` is True; otherwise resolved by `_get_filters`
            from `F`.
        conv_shortcut: if True, the shortcut is a Dense + BatchNormalization
            projection to `filters * 4` channels; otherwise the gathered
            input features are used unchanged.

    Returns:
        [N_out, filters * 4] output features, where N_out is the number of
        sample indices.

    Raises:
        ValueError: if `conv_shortcut` is True and `filters` is None.
    """
    # Shortcut path: take the input features of the sampled nodes.
    shortcut = tf.gather(node_features, sample_indices)
    if conv_shortcut:
        if filters is None:
            # Fail clearly instead of with a TypeError from `None * 4`.
            raise ValueError(
                'filters must be provided if conv_shortcut is True')
        shortcut = Dense(filters * 4, name=name + '_shortcut_dense')(shortcut)
        shortcut = BatchNormalization(name=name + '_shortcut_bn')(shortcut)
    else:
        filters = _get_filters(filters, node_features.shape[-1])

    # The number of output rows equals the number of sampled nodes.
    out_size = tf.shape(sample_indices, out_type=tf.int64)[0]

    # Bottleneck: reduce to `filters` channels, sparse conv, expand to 4x.
    x = Dense(filters, name=name + '_bottleneck_0_dense')(node_features)
    x = BatchNormalization(name=name + '_bottleneck_0_bn')(x)
    x = Activation('relu', name=name + '_bottleneck_0_relu')(x)
    x = sparse_layers.SparseCloudConvolution(filters, name=name + '_conv')(
        [x, edge_features, neigh_indices, out_size])
    x = BatchNormalization(name=name + '_conv_bn')(x)
    x = Activation('relu', name=name + '_conv_relu')(x)
    x = Dense(filters * 4, name=name + '_bottleneck_1_dense')(x)
    x = BatchNormalization(name=name + '_bottleneck_1_bn')(x)

    # Residual connection.
    x = Add(name=name + '_add')([shortcut, x])
    x = Activation('relu', name=name + '_out')(x)
    return x
def final_up_sample(node_features, edge_features, neigh_indices, out_size,
                    num_classes, name='up_sample_final'):
    """
    Final sparse convolution producing `num_classes` outputs per node.

    Args:
        node_features: input node features.
        edge_features: edge features forwarded to the sparse convolution.
        neigh_indices: sparse neighbor indices forwarded to the sparse
            convolution.
        out_size: number of output nodes, forwarded to the sparse
            convolution.
        num_classes: number of output channels.
        name: name given to the created convolution layer.

    Returns:
        Output of a `SparseCloudConvolution` with `num_classes` filters. No
        normalization or activation is applied afterwards.
    """
    # `name` was previously accepted but ignored; forward it to the layer.
    return sparse_layers.SparseCloudConvolution(num_classes, name=name)(
        [node_features, edge_features, neigh_indices, out_size])
def up_sample_combine(upper_node_features, node_features, edge_features,
                      neigh_indices, name):
    """
    Up-sample `node_features` onto the upper resolution and add residually.

    Args:
        upper_node_features: features at the target (upper) resolution. Its
            row count sets the output size and its last dimension sets the
            number of convolution filters.
        node_features: lower-resolution features to up-sample.
        edge_features: edge features forwarded to the sparse convolution.
        neigh_indices: sparse neighbor indices forwarded to the sparse
            convolution.
        name: string prefix for the names of all created layers.

    Returns:
        relu(upper_node_features + BN(conv(node_features))), with the same
        number of rows and channels as `upper_node_features`.
    """
    out_size = tf.shape(upper_node_features, out_type=tf.int64)[0]
    # Match the channel count of the upper features so the Add is valid.
    # Named consistently with the other layers created by this block.
    x = sparse_layers.SparseCloudConvolution(
        upper_node_features.shape[-1], name=name + '_conv')(
            [node_features, edge_features, neigh_indices, out_size])
    x = BatchNormalization(name=name + '_bn')(x)
    x = Add(name=name + '_add')([upper_node_features, x])
    x = Activation('relu', name=name + '_out')(x)
    return x
def conv(self, node_features, **layer_kwargs):
    """
    Apply a sparse cloud convolution using this object's weighted edges.

    Args:
        node_features: input node features, or None to use a featureless
            convolution that only consumes edge features.
        **layer_kwargs: keyword arguments forwarded to the layer constructor.

    Returns:
        Output of `SparseCloudConvolution` when `node_features` is given
        (output size taken from `self._out_cloud.trained_total_size`),
        otherwise output of `FeaturelessSparseCloudConvolution`.
    """
    weighted = self.weighted_edge_features
    if node_features is not None:
        conv_layer = sparse_layers.SparseCloudConvolution(**layer_kwargs)
        inputs = [
            node_features,
            weighted,
            self.sparse_indices,
            self._out_cloud.trained_total_size,
        ]
        return conv_layer(inputs)
    # No node features: the convolution is driven purely by edge features.
    featureless_layer = sparse_layers.FeaturelessSparseCloudConvolution(
        **layer_kwargs)
    return featureless_layer([weighted, self.sparse_indices])
def global_conv(self, node_features, edge_features_fn, edge_weights_fn,
                **layer_kwargs):
    """
    Sparse convolution over this object's global neighborhood.

    Edge features and weights are computed from the transposed trained
    coordinates. Weights are normalized with `row_normalize`, grouped by the
    first column of `trained_global_sparse_indices`, and multiply the edge
    features along axis 0.

    Args:
        node_features: input node features.
        edge_features_fn: callable mapping transposed coordinates to edge
            features.
        edge_weights_fn: callable mapping transposed coordinates to edge
            weights.
        **layer_kwargs: keyword arguments forwarded to
            `SparseCloudConvolution`.

    Returns:
        Output of `SparseCloudConvolution`, with output size
        `self.trained_coords.nrows()`.
    """
    coords_t = tf.transpose(self.trained_coords, (1, 0))
    features = edge_features_fn(coords_t)
    weights = edge_weights_fn(coords_t)
    indices = self.trained_global_sparse_indices
    # Normalize weights per value of the first index column (presumably the
    # output/row node — TODO confirm against row_normalize).
    row_ids = tf.gather(indices, 0, axis=1)
    weights = row_normalize(weights, row_ids)
    features = features * tf.expand_dims(weights, axis=0)
    layer = sparse_layers.SparseCloudConvolution(**layer_kwargs)
    num_out = self.trained_coords.nrows(out_type=tf.int64)
    return layer([node_features, features, indices, num_out])
def in_place_block(node_features, edge_features, neigh_indices, name,
                   filters=None, conv_shortcut=False):
    """
    Residual block for in-place convolution.

    Analogous to image convolutions with `stride == 1, padding='SAME'`.

    Args:
        node_features: [N, F] initial node features.
        edge_features: [E, F_edge] edge features.
        neigh_indices: [E, 2] sparse indices of neighbors.
        name: string name.
        filters: number of filters in bottleneck layer. If None, uses
            `F // 4`.
        conv_shortcut: if True, uses a dense layer on shortcut connection,
            otherwise uses identity and `filters` must be `None` or `F // 4`,
            and `F % 4 == 0`.

    Returns:
        [N, filters * 4] output features in the same node ordering.

    Raises:
        ValueError: if `conv_shortcut` is True and `filters` is None.
    """
    out_size = tf.shape(node_features, out_type=tf.int64)[0]
    if conv_shortcut:
        if filters is None:
            raise ValueError(
                'filters must be provided if conv_shortcut is True')
        shortcut = Dense(4 * filters,
                         name=name + '_shortcut_dense')(node_features)
        shortcut = BatchNormalization(name=name + '_shortcut_bn')(shortcut)
    else:
        shortcut = node_features
        filters = _get_filters(filters, node_features.shape[-1])

    # Bottleneck: reduce to `filters` channels before the sparse conv. This
    # matches `down_sample_block`; previously `filters * 4` was used here,
    # which performed no channel reduction at all.
    x = Dense(filters, name=name + '_bottleneck_0_dense')(node_features)
    x = BatchNormalization(name=name + '_bottleneck_0_bn')(x)
    x = Activation('relu', name=name + '_bottleneck_0_relu')(x)
    x = sparse_layers.SparseCloudConvolution(filters, name=name + '_conv')(
        [x, edge_features, neigh_indices, out_size])
    x = BatchNormalization(name=name + '_conv_bn')(x)
    x = Activation('relu', name=name + '_conv_relu')(x)
    # Expand back to `filters * 4` so the residual Add matches the shortcut.
    x = Dense(filters * 4, name=name + '_bottleneck_1_dense')(x)
    x = BatchNormalization(name=name + '_bottleneck_1_bn')(x)
    x = Add(name=name + '_add')([shortcut, x])
    x = Activation('relu', name=name + '_out')(x)
    return x