示例#1
0
def get_data():
  """Builds the training and test/generalization graph batches.

  Calls `create_data` for the training split (`batch_size_tr`,
  `num_elements_min_max_tr`) and then for the test/generalization split
  (`batch_size_ge`, `num_elements_min_max_ge`); these are module-level names
  defined outside this function. Input graphs receive size-1 zero edge
  features and then size-1 zero global features. Target graphs receive
  size-1 zero global features only.

  Returns:
    A 6-tuple `(inputs_tr, targets_tr, sort_indices_tr,
    inputs_ge, targets_ge, sort_indices_ge)`.
  """

  def _zero_fill_inputs(graphs):
    # Edge features first, then globals.
    with_edges = utils_tf.set_zero_edge_features(graphs, 1)
    return utils_tf.set_zero_global_features(with_edges, 1)

  train_in, train_tgt, train_sort, _ = create_data(
      batch_size_tr, num_elements_min_max_tr)
  train_in = _zero_fill_inputs(train_in)

  # Test/generalization split.
  gen_in, gen_tgt, gen_sort, _ = create_data(
      batch_size_ge, num_elements_min_max_ge)
  gen_in = _zero_fill_inputs(gen_in)

  # Targets only get zero globals here (training split first, then test).
  train_tgt, gen_tgt = [
      utils_tf.set_zero_global_features(tgt, 1)
      for tgt in (train_tgt, gen_tgt)]

  return train_in, train_tgt, train_sort, gen_in, gen_tgt, gen_sort
示例#2
0
 def test_fill_global_state(self, global_size):
   """Zero-filled globals must have static shape (n_graphs, global_size)."""
   # Start from graphs that carry no global features at all.
   for graph_dict in self.graphs_dicts_in:
     graph_dict.pop("globals")
   num_graphs = self.reference_graph.n_edge.shape[0]
   filled = utils_tf.set_zero_global_features(
       utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in), global_size)
   expected_shape = (num_graphs, global_size)
   self.assertAllEqual(expected_shape, filled.globals.get_shape().as_list())
示例#3
0
 def test_fill_state_user_specified_types(self, dtype):
   """Zero-filled edges, nodes and globals must use the requested `dtype`."""
   removed_fields = ("nodes", "globals", "edges")
   for graph_dict in self.graphs_dicts_in:
     for field in removed_fields:
       graph_dict.pop(field)
   graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   # Fill edges, then nodes, then globals, each with size 1 and `dtype`.
   for set_zero in (utils_tf.set_zero_edge_features,
                    utils_tf.set_zero_node_features,
                    utils_tf.set_zero_global_features):
     graphs_tuple = set_zero(graphs_tuple, 1, dtype)
   for field in ("edges", "nodes", "globals"):
     self.assertEqual(dtype, getattr(graphs_tuple, field).dtype)
示例#4
0
 def test_fill_state_default_types(self):
   """Without an explicit dtype, zero-filled features must be tf.float32."""
   for graph_dict in self.graphs_dicts_in:
     for field in ("nodes", "globals", "edges"):
       graph_dict.pop(field)
   graphs = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   # Sizes are passed by keyword and no dtype is given, so the default applies.
   graphs = utils_tf.set_zero_edge_features(graphs, edge_size=1)
   graphs = utils_tf.set_zero_node_features(graphs, node_size=1)
   graphs = utils_tf.set_zero_global_features(graphs, global_size=1)
   for features in (graphs.edges, graphs.nodes, graphs.globals):
     self.assertEqual(tf.float32, features.dtype)
示例#5
0
 def test_fill_global_state_dynamic(self, global_size):
   """Zero globals are filled correctly when `n_node` has no static length."""
   for graph_dict in self.graphs_dicts_in:
     graph_dict.pop("globals")
   graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
   # Swap `n_node` for a placeholder of unknown length. This hides static
   # shape information from set_zero_global_features.
   dynamic_n_node = tf.placeholder_with_default(
       graphs_tuple.n_node, shape=[None])
   graphs_tuple = graphs_tuple._replace(n_node=dynamic_n_node)
   expected_globals = np.zeros(
       (self.reference_graph.n_edge.shape[0], global_size))
   graphs_tuple = utils_tf.set_zero_global_features(graphs_tuple, global_size)
   with self.test_session() as sess:
     actual_globals = sess.run(graphs_tuple.globals)
   self.assertNDArrayNear(expected_globals, actual_globals, err=1e-4)
示例#6
0
def zeros_graph(sample_graph, edge_size, node_size, global_size):
    """Return `sample_graph` with all features replaced by zeros.

    The nodes, edges and globals of `sample_graph` are cleared via
    `sample_graph.replace(...)`. Zero features are then filled in, in the
    order edges (`edge_size`), nodes (`node_size`), globals (`global_size`).
    """
    graph = sample_graph.replace(nodes=None, edges=None, globals=None)
    fillers = (
        (utils_tf.set_zero_edge_features, edge_size),
        (utils_tf.set_zero_node_features, node_size),
        (utils_tf.set_zero_global_features, global_size),
    )
    for set_zero, size in fillers:
        graph = set_zero(graph, size)
    return graph
def main():
    """Zero-fill the globals of the first test graph and inspect it.

    Takes element `[0][0]` of `test_graph()`. It is converted to a
    graphs tuple, given size-1 zero global features, and handed to
    `dtype_shape_from_graphs_tuple`. The result of that call is discarded.
    """
    first_graph_dict = test_graph()[0][0]
    graphs = utils_tf.data_dicts_to_graphs_tuple([first_graph_dict])
    dtype_shape_from_graphs_tuple(
        utils_tf.set_zero_global_features(graphs, 1))