コード例 #1
0
 def test_fill_node_state(self, node_size):
   """Checks the node tensor shape produced by `set_zero_node_features`.

   Every input dict loses its "nodes" entry (its row count is kept as
   "n_node"). The graphs tuple built from those dicts is then given zero node
   features of width `node_size`. The resulting node tensor must have shape
   (total node count of the reference graph, node_size).
   """
   for data_dict in self.graphs_dicts_in:
     # Keep the node count before the node features are discarded.
     node_count = data_dict["nodes"].shape[0]
     del data_dict["nodes"]
     data_dict["n_node"] = node_count
   graphs = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   expected_total = np.sum(self.reference_graph.n_node)
   graphs = utils_tf.set_zero_node_features(graphs, node_size)
   actual_shape = graphs.nodes.get_shape().as_list()
   self.assertAllEqual((expected_total, node_size), actual_shape)
コード例 #2
0
 def test_fill_state_user_specified_types(self, dtype):
   """Tests that the features are created with the user-specified `dtype`.

   All feature fields ("nodes", "globals", "edges") are removed from the input
   dicts. Zero edge, node and global features of width 1 are then added with
   an explicit `dtype`, and each resulting feature tensor must carry exactly
   that dtype.

   Args:
     dtype: the dtype requested from the setters (presumably supplied by a
       parameterized-test decorator outside this view — confirm).
   """
   for g in self.graphs_dicts_in:
     for field in ("nodes", "globals", "edges"):
       g.pop(field)
   graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   # `dtype` is passed positionally as the third argument of each setter.
   graphs_tuple = utils_tf.set_zero_edge_features(graphs_tuple, 1, dtype)
   graphs_tuple = utils_tf.set_zero_node_features(graphs_tuple, 1, dtype)
   graphs_tuple = utils_tf.set_zero_global_features(graphs_tuple, 1, dtype)
   self.assertEqual(dtype, graphs_tuple.edges.dtype)
   self.assertEqual(dtype, graphs_tuple.nodes.dtype)
   self.assertEqual(dtype, graphs_tuple.globals.dtype)
コード例 #3
0
 def test_fill_state_default_types(self):
   """Checks that zero features default to tf.float32 when no dtype is given.

   All feature fields are stripped from the input dicts. The edge, node and
   global features are then refilled with width 1 without passing a dtype.
   """
   stripped_fields = ("nodes", "globals", "edges")
   for data_dict in self.graphs_dicts_in:
     for field in stripped_fields:
       data_dict.pop(field)
   graphs = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   graphs = utils_tf.set_zero_edge_features(graphs, edge_size=1)
   graphs = utils_tf.set_zero_node_features(graphs, node_size=1)
   graphs = utils_tf.set_zero_global_features(graphs, global_size=1)
   # Checked in the order edges, nodes, globals.
   for features in (graphs.edges, graphs.nodes, graphs.globals):
     self.assertEqual(tf.float32, features.dtype)
コード例 #4
0
 def test_fill_node_state_dynamic(self, node_size):
   """Checks zero node features after `n_node` is rebuilt as a new constant.

   The node features are dropped from the input dicts (only the count is kept
   as "n_node"). `n_node` is re-created with `tf.constant`, and zero node
   features of width `node_size` are added. The values must be all zeros
   (within 1e-4) with shape (total reference node count, node_size).
   """
   for data_dict in self.graphs_dicts_in:
     node_count = data_dict["nodes"].shape[0]
     del data_dict["nodes"]
     data_dict["n_node"] = node_count
   graphs = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   # NOTE(review): the shape passed here is the tensor's own static shape, and
   # `.numpy()` below implies eager mode, so this may not exercise a truly
   # dynamic `n_node` shape — confirm intent.
   static_shape = graphs.n_node.get_shape()
   graphs = graphs._replace(
       n_node=tf.constant(graphs.n_node, shape=static_shape))
   expected_total = np.sum(self.reference_graph.n_node)
   graphs = utils_tf.set_zero_node_features(graphs, node_size)
   observed = graphs.nodes.numpy()
   expected = np.zeros((expected_total, node_size))
   self.assertNDArrayNear(expected, observed, err=1e-4)
コード例 #5
0
def zeros_graph(sample_graph, edge_size, node_size, global_size):
    """Build a graph from `sample_graph` with its features reset via set_zero_*.

    The nodes, edges and globals of `sample_graph` are cleared with `replace`.
    They are then refilled by `utils_tf.set_zero_edge_features`,
    `set_zero_node_features` and `set_zero_global_features`, with widths
    `edge_size`, `node_size` and `global_size` respectively.

    Returns:
        The graph returned by the last setter (the global-feature one).
    """
    graph = sample_graph.replace(nodes=None, edges=None, globals=None)
    # Applied in order: edges, then nodes, then globals.
    fillers = (
        (utils_tf.set_zero_edge_features, edge_size),
        (utils_tf.set_zero_node_features, node_size),
        (utils_tf.set_zero_global_features, global_size),
    )
    for set_zeros, width in fillers:
        graph = set_zeros(graph, width)
    return graph