def test_get_padding_values(self):
    """The padding dict returned by the shape fills node and edge features with 0.0."""
    graph_shape = TensorGraphShape(num_nodes=42, node_dims=8, edge_dims=6)
    padding = graph_shape.get_padding_values()
    # Look both entries up first, then check each against the scalar zero.
    node_pad, edge_pad = padding[NODE_FEATURES], padding[EDGE_FEATURES]
    for pad in (node_pad, edge_pad):
        self.assertAllEqual(pad, 0.0)
def test_constructor_defaults(self):
    """A default-constructed TensorGraphShape has these defaults:

    - an unspecified node count;
    - single-dimension node and edge features;
    - float32 dtypes for both.
    """
    shape = TensorGraphShape()
    # Use TestCase assertions rather than bare `assert`: these survive
    # `python -O` and report the actual vs. expected values on failure.
    self.assertIsNone(shape.num_nodes)
    self.assertEqual(shape.node_dims, 1)
    self.assertEqual(shape.edge_dims, 1)
    self.assertEqual(shape.node_dtype, tf.float32)
    self.assertEqual(shape.edge_dtype, tf.float32)
def test_constructor_with_values(self):
    """Every constructor argument is stored verbatim on the TensorGraphShape."""
    shape = TensorGraphShape(num_nodes=42, node_dims=8, edge_dims=6,
                             node_dtype=tf.float64, edge_dtype=tf.int32)
    # TestCase assertions instead of bare `assert` so failures show values
    # and the checks are not stripped under `python -O`.
    self.assertEqual(shape.num_nodes, 42)
    self.assertEqual(shape.node_dims, 8)
    self.assertEqual(shape.edge_dims, 6)
    self.assertEqual(shape.node_dtype, tf.float64)
    self.assertEqual(shape.edge_dtype, tf.int32)
def test_simple_graph(self):
    """TensorGraphInput builds node/edge inputs with the expected shapes and name prefixes.

    Node features are expected as (None, num_nodes, node_dims). Edge features
    are expected as (None, edge_dims, num_nodes, num_nodes). Both expectations
    are taken from the assertions below.
    """
    shape = TensorGraphShape(num_nodes=42, node_dims=8, edge_dims=6,
                             node_dtype=tf.float64, edge_dtype=tf.float32)
    # Named `graph_input` (not `input`) to avoid shadowing the builtin.
    graph_input = TensorGraphInput(shape)
    node_features = graph_input[NODE_FEATURES]
    edge_features = graph_input[EDGE_FEATURES]

    self.assertAllEqual(node_features.shape, (None, 42, 8))
    self.assertAllEqual(edge_features.shape, (None, 6, 42, 42))
    # assertTrue with a message reports the actual name on failure, unlike a
    # bare `assert`, which is also stripped under `python -O`.
    self.assertTrue(node_features.name.startswith('graph_node_features'),
                    msg='unexpected node input name: %r' % node_features.name)
    self.assertTrue(edge_features.name.startswith('graph_edge_features'),
                    msg='unexpected edge input name: %r' % edge_features.name)
def test_define_tfrecord_features_with_name(self):
    """define_tfrecord_features prefixes every feature key with the given name.

    The expected spec is:
    - 'num_nodes' as a scalar int64 FixedLenFeature;
    - node/edge features as VarLenFeatures of the shape's node/edge dtypes.
    """
    shape = TensorGraphShape(num_nodes=42, node_dims=8, edge_dims=6,
                             node_dtype=tf.float64, edge_dtype=tf.float32)
    feat = define_tfrecord_features(shape, name='test')
    # NOTE(review): tf.io.parse_example documents VarLenFeature dtypes as
    # float32/int64/string only, so a float64 node spec may fail at parse
    # time even though this equality check passes — confirm intended.
    # assertDictEqual prints a per-key diff on mismatch, unlike bare `assert`.
    self.assertDictEqual(feat, {
        'test:num_nodes': tf.io.FixedLenFeature([], tf.int64),
        'test:node_features': tf.io.VarLenFeature(tf.float64),
        'test:edge_features': tf.io.VarLenFeature(tf.float32),
    })