def test_constructor_defaults(self):
        """A TensorGraphShape built with no arguments uses the default settings.

        num_nodes is left as None, both feature dimensions are 1, and both
        feature dtypes are tf.float32.
        """
        default_shape = TensorGraphShape()

        # No node count is fixed unless one is passed explicitly.
        self.assertIsNone(default_shape.num_nodes)

        # Feature sizes default to 1 for both nodes and edges.
        self.assertEqual(default_shape.node_dims, 1)
        self.assertEqual(default_shape.edge_dims, 1)

        # Feature dtypes default to float32 for both nodes and edges.
        self.assertEqual(default_shape.node_dtype, tf.float32)
        self.assertEqual(default_shape.edge_dtype, tf.float32)
    def test_get_padding_values(self):
        """The padding values for node and edge features are all zeros."""
        padded_shape = TensorGraphShape(num_nodes=42, node_dims=8, edge_dims=6)

        padding = padded_shape.get_padding_values()
        node_pad = padding[NODE_FEATURES]
        edge_pad = padding[EDGE_FEATURES]

        # Every element of both padding tensors must be 0.0.
        for pad in (node_pad, edge_pad):
            self.assertAllEqual(pad, 0.0)
    def test_constructor_with_values(self):
        """Explicit constructor arguments are exposed unchanged as attributes."""
        # One mapping supplies both the constructor keywords and the values
        # every attribute must report back.
        expected = dict(
            num_nodes=42,
            node_dims=8,
            edge_dims=6,
            node_dtype=tf.float64,
            edge_dtype=tf.int32,
        )
        custom_shape = TensorGraphShape(**expected)

        for attr_name, attr_value in expected.items():
            self.assertEqual(getattr(custom_shape, attr_name), attr_value)
# Example no. 4 (score: 0)
    def test_simple_graph(self):
        """TensorGraphInput exposes node and edge feature inputs with the expected shapes and names.

        With num_nodes=42, node_dims=8 and edge_dims=6, the node features must
        have shape (None, 42, 8) and the edge features (None, 6, 42, 42). The
        leading None is presumably the batch axis (TODO confirm against
        TensorGraphInput). The input names must start with 'graph_node_features'
        and 'graph_edge_features' respectively.
        """
        shape = TensorGraphShape(num_nodes=42,
                                 node_dims=8, edge_dims=6,
                                 node_dtype=tf.float64, edge_dtype=tf.float32)

        # Named graph_input (not `input`) so the builtin input() is not shadowed.
        graph_input = TensorGraphInput(shape)
        input_node_features = graph_input[NODE_FEATURES]
        input_edge_features = graph_input[EDGE_FEATURES]

        self.assertAllEqual(input_node_features.shape, (None, 42, 8))
        self.assertAllEqual(input_edge_features.shape, (None, 6, 42, 42))

        assert input_node_features.name.startswith('graph_node_features')
        assert input_edge_features.name.startswith('graph_edge_features')
# Example no. 5 (score: 0)
    def test_define_tfrecord_features_with_name(self):
        """With name='test', every feature key is prefixed with 'test:'.

        num_nodes must be a scalar tf.int64 FixedLenFeature. The node and edge
        features must be VarLenFeatures using the dtypes configured on the shape.
        """
        shape = TensorGraphShape(
            num_nodes=42, node_dims=8, edge_dims=6,
            node_dtype=tf.float64, edge_dtype=tf.float32,
        )

        feat = define_tfrecord_features(shape, name='test')

        # The VarLenFeature dtypes mirror node_dtype / edge_dtype above.
        expected_features = {
            'test:num_nodes': tf.io.FixedLenFeature([], tf.int64),
            'test:node_features': tf.io.VarLenFeature(tf.float64),
            'test:edge_features': tf.io.VarLenFeature(tf.float32),
        }
        assert feat == expected_features