Ejemplo n.º 1
0
    def test_feature_steered_convolution_layer_initializer(self):
        """Checks that a zeros initializer yields an all-zeros layer output.

        The input is the four corners of a square; each corner is connected
        to itself and to its two adjacent corners with uniform weight 1/3.
        """
        corners = np.array(((1.0, 1.0), (-1.0, 1.0), (-1.0, -1.0), (1.0, -1.0)))
        edges = np.array(((0, 0), (0, 1), (0, 3),
                          (1, 0), (1, 1), (1, 2),
                          (2, 1), (2, 2), (2, 3),
                          (3, 0), (3, 2), (3, 3)))
        edge_weights = np.ones(shape=(12, )) / 3.0
        adjacency = tf.SparseTensor(edges, edge_weights, dense_shape=(4, 4))
        zeros_init = tf.compat.v1.keras.initializers.zeros()

        if not tf.executing_eagerly():
            # Graph mode: build the op, initialize its variables, then run it.
            graph_output = gc_layer.feature_steered_convolution_layer(
                data=corners,
                neighbors=adjacency,
                sizes=None,
                translation_invariant=False,
                initializer=zeros_init)
            self.evaluate(tf.compat.v1.global_variables_initializer())
            result = self.evaluate(graph_output)
        else:
            keras_layer = gc_layer.FeatureSteeredConvolutionKerasLayer(
                translation_invariant=False, initializer=zeros_init)
            result = keras_layer(inputs=[corners, adjacency], sizes=None)

        # With every weight initialized to zero, the output must be all zeros.
        self.assertAllEqual(result, np.zeros_like(corners))
Ejemplo n.º 2
0
def mesh_encoder(batch_mesh_data, num_filters, output_dim, conv_layer_dims):
    """A mesh encoder using feature steered graph convolutions.

    The shorthands used below are
        `B`: Batch size.
        `V`: The maximum number of vertices over all meshes in the batch.
        `D`: The number of dimensions of input vertex features, D=3 if vertex
            positions are used as features.

    Args:
        batch_mesh_data: A mesh_data dict with following keys
            'vertices': A [B, V, D] `float32` tensor of vertex features,
                possibly 0-padded.
            'neighbors': A [B, V, V] `float32` sparse tensor of edge weights.
            'num_vertices': A [B] `int32` tensor of number of vertices per
                mesh.
        num_filters: The number of weight matrices to be used in feature
            steered graph conv.
        output_dim: A dimension of output per vertex features.
        conv_layer_dims: A list of dimensions used in graph convolution layers.

    Returns:
        vertex_features: A [B, V, output_dim] `float32` tensor of per vertex
        features.
    """
    batch_vertices = batch_mesh_data['vertices']

    # Per-vertex linear projection (kernel size 1): [B, V, D] --> [B, V, 16].
    vertex_features = tf.keras.layers.Conv1D(16, 1, name='lin16')(batch_vertices)

    # Graph convolution layers, each followed by a ReLU:
    # [B, V, C_in] --> [B, V, dim] for every `dim` in conv_layer_dims.
    for dim in conv_layer_dims:
        # `tf.variable_scope` is not available in TF2; the `tf.compat.v1`
        # alias is the same function, so the scope names are unchanged.
        with tf.compat.v1.variable_scope('conv_%d' % dim):
            vertex_features = graph_conv.feature_steered_convolution_layer(
                vertex_features,
                batch_mesh_data['neighbors'],
                batch_mesh_data['num_vertices'],
                num_weight_matrices=num_filters,
                num_output_channels=dim,
                translation_invariant=True)
            vertex_features = tf.nn.relu(vertex_features)

    # Per-vertex linear projection: [B, V, C] --> [B, V, 256], where C is the
    # last entry of conv_layer_dims (16 if conv_layer_dims is empty).
    vertex_features = tf.keras.layers.Conv1D(256, 1, name='lin256')(vertex_features)
    vertex_features = tf.nn.relu(vertex_features)

    # Per-vertex linear projection: [B, V, 256] --> [B, V, output_dim].
    vertex_features = tf.keras.layers.Conv1D(output_dim, 1, name='lin_output')(vertex_features)

    return vertex_features
Ejemplo n.º 3
0
    def test_feature_steered_convolution_layer_training(self):
        """Test a simple training loop.

        Runs a few gradient-descent steps through the layer as a smoke test.
        The input is the four corners of a square; each corner is connected
        to itself and to its two adjacent corners with uniform weight 1/3.
        Only successful execution is checked; the loss value is not asserted.
        """
        # Generate a small valid input for a simple training task.
        # Four corners of a square.
        data = np.array(((1.0, 1.0), (-1.0, 1.0), (-1.0, -1.0), (1.0, -1.0)))
        neighbors_indices = np.array(
            ((0, 0), (0, 1), (0, 3), (1, 0), (1, 1), (1, 2), (2, 1), (2, 2),
             (2, 3), (3, 0), (3, 2), (3, 3)))
        neighbors = tf.SparseTensor(neighbors_indices,
                                    np.ones(shape=(12, )) / 3.0,
                                    dense_shape=(4, 4))
        # Desired output is arbitrary.
        labels = np.reshape([-1.0, -0.5, 0.5, 1.0], (-1, 1))
        num_training_iterations = 5
        learning_rate = 1e-4

        if tf.executing_eagerly():
            layer = gc_layer.FeatureSteeredConvolutionKerasLayer(
                translation_invariant=False,
                num_weight_matrices=1,
                num_output_channels=1)
            optimizer = tf.compat.v1.train.GradientDescentOptimizer(
                learning_rate)
            for _ in range(num_training_iterations):
                # Record a fresh forward pass every step so gradients are taken
                # at the current (already updated) variable values rather than
                # reusing the gradients of the initial weights.
                with tf.GradientTape() as tape:
                    output = layer(inputs=[data, neighbors], sizes=None)
                    loss = tf.nn.l2_loss(output - labels)
                trainable_variables = layer.trainable_variables
                grads = tape.gradient(loss, trainable_variables)
                optimizer.apply_gradients(
                    list(zip(grads, trainable_variables)))
        else:
            output = gc_layer.feature_steered_convolution_layer(
                data=data,
                neighbors=neighbors,
                sizes=None,
                translation_invariant=False,
                num_weight_matrices=1,
                num_output_channels=1)
            train_op = tf.compat.v1.train.GradientDescentOptimizer(
                learning_rate).minimize(tf.nn.l2_loss(output - labels))
            # `initialize_all_variables` is deprecated; use
            # `global_variables_initializer` with the test's own session.
            self.evaluate(tf.compat.v1.global_variables_initializer())
            for _ in range(num_training_iterations):
                self.evaluate(train_op)
Ejemplo n.º 4
0
 def _run_convolution():
     """Build and run the feature steered convolution for the current mode.

     In eager mode the `layer` object from the enclosing scope is called
     (presumably a FeatureSteeredConvolutionKerasLayer -- confirm against the
     caller); otherwise the functional graph-mode layer is constructed.
     Any exception raised while doing so fails the test via `self.fail`.
     """
     eager_mode = tf.executing_eagerly()
     try:
         if eager_mode:
             result = layer(inputs=[data, neighbors], sizes=None)
         else:
             result = gc_layer.feature_steered_convolution_layer(
                 data=data,
                 neighbors=neighbors,
                 sizes=None,
                 translation_invariant=translation_invariant,
                 num_weight_matrices=num_weight_matrices,
                 num_output_channels=out_channels,
                 var_name=name_scope)
     except Exception as err:  # pylint: disable=broad-except
         self.fail("Exception raised: %s" % str(err))
     return result