コード例 #1
0
 def _get_input_graph(self, none_fields=None):
     """Builds a four-graph `GraphsTuple` fixture, optionally blanking fields.

     Args:
       none_fields: optional list of `GraphsTuple` field names whose values are
         replaced with `None`. Defaults to clearing no fields.

     Returns:
       A `GraphsTuple` built from SMALL_GRAPH_1 .. SMALL_GRAPH_4.
     """
     fields_to_clear = [] if none_fields is None else none_fields
     fixtures = [SMALL_GRAPH_1, SMALL_GRAPH_2, SMALL_GRAPH_3, SMALL_GRAPH_4]
     graph = utils_tf.data_dicts_to_graphs_tuple(fixtures)
     return graph.map(lambda _: None, fields_to_clear)
コード例 #2
0
 def test_output_values(self, broadcaster, expected):
     """Checks the values `broadcaster` produces on a two-graph batch."""
     graphs_in = utils_tf.data_dicts_to_graphs_tuple(
         [SMALL_GRAPH_1, SMALL_GRAPH_2])
     output_tensor = broadcaster(graphs_in)
     with self.test_session() as session:
         actual = session.run(output_tensor)
     expected_array = np.array(expected, dtype=np.float32)
     self.assertNDArrayNear(expected_array, actual, err=1e-4)
コード例 #3
0
 def test_fill_edge_state(self, edge_size):
     """Checks that missing edge features are zero-filled with a static shape."""
     for graph_dict in self.graphs_dicts_in:
         del graph_dict["edges"]
     graphs_in = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
     expected_num_edges = np.sum(self.reference_graph.n_edge)
     filled = utils_tf.set_zero_edge_features(graphs_in, edge_size)
     static_shape = filled.edges.get_shape().as_list()
     self.assertAllEqual((expected_num_edges, edge_size), static_shape)
コード例 #4
0
ファイル: main.py プロジェクト: Tubbz-alt/graph_net
def create_data_ops(batch_size, num_elements_min_max):
    """Returns graphs containing the inputs and targets for classification.

    Refer to `create_graph_dicts_tf` and `create_linked_list_target` for more
    details.

    The data dicts returned by `create_graph_dicts_tf` are converted to
    `graphs.GraphsTuple`s with `utils_tf.data_dicts_to_graphs_tuple`, and edges
    are then added to every graph by fully-connecting its nodes with
    `utils_tf.fully_connect_graph_dynamic`.

    Args:
      batch_size: batch size for the `input_graphs`.
      num_elements_min_max: a 2-`tuple` of `int`s which define the [lower, upper)
        range of the number of elements per list.

    Returns:
      inputs_op: a `graphs.GraphsTuple` which contains the input list as a graph.
      targets_op: a `graphs.GraphsTuple` which contains the target as a graph.
        Its node and edge features are the features `t` produced by
        `create_linked_list_target`, concatenated along axis 1 with their
        complement `1 - t`.
      sort_indices_op: a `graphs.GraphsTuple` which contains the sort indices of
        the list elements as a graph.
      ranks_op: a `graphs.GraphsTuple` which contains the ranks of the list
        elements as a graph.
    """
    inputs_op, sort_indices_op, ranks_op = create_graph_dicts_tf(
        batch_size, num_elements_min_max)

    # Data dicts -> GraphsTuples.
    inputs_op = utils_tf.data_dicts_to_graphs_tuple(inputs_op)
    sort_indices_op = utils_tf.data_dicts_to_graphs_tuple(sort_indices_op)
    ranks_op = utils_tf.data_dicts_to_graphs_tuple(ranks_op)

    # Add edges to each graph by fully-connecting its nodes.
    inputs_op = utils_tf.fully_connect_graph_dynamic(inputs_op)
    sort_indices_op = utils_tf.fully_connect_graph_dynamic(sort_indices_op)
    ranks_op = utils_tf.fully_connect_graph_dynamic(ranks_op)

    targets_op = create_linked_list_target(batch_size, sort_indices_op)
    # Append the complement of each target feature so nodes and edges carry
    # both `t` and `1 - t`.
    nodes = tf.concat((targets_op.nodes, 1.0 - targets_op.nodes), axis=1)
    edges = tf.concat((targets_op.edges, 1.0 - targets_op.edges), axis=1)
    targets_op = targets_op._replace(nodes=nodes, edges=edges)

    return inputs_op, targets_op, sort_indices_op, ranks_op
コード例 #5
0
 def test_fill_global_state(self, global_size):
     """Checks that missing globals are zero-filled with a static shape."""
     for graph_dict in self.graphs_dicts_in:
         del graph_dict["globals"]
     graphs_in = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
     # One entry of `n_edge` per graph, so its length is the number of graphs.
     num_graphs = self.reference_graph.n_edge.shape[0]
     filled = utils_tf.set_zero_global_features(graphs_in, global_size)
     static_shape = filled.globals.get_shape().as_list()
     self.assertAllEqual((num_graphs, global_size), static_shape)
コード例 #6
0
 def test_fill_edge_state_with_missing_fields_raises(self):
     """Edge features cannot be filled if receivers or senders are missing."""
     for graph_dict in self.graphs_dicts_in:
         graph_dict.pop("receivers")
         graph_dict.pop("senders")
         graph_dict.pop("edges")
     graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(
         self.graphs_dicts_in)
     # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12).
     # The call's result is not bound because it is expected to raise.
     with self.assertRaisesRegex(ValueError, "receivers"):
         utils_tf.set_zero_edge_features(graphs_tuple, edge_size=1)
コード例 #7
0
 def test_fill_node_state(self, node_size):
     """Checks that missing node features are zero-filled with a static shape."""
     for graph_dict in self.graphs_dicts_in:
         # Node features are dropped, so the node count is given explicitly.
         num_nodes = graph_dict["nodes"].shape[0]
         graph_dict["n_node"] = num_nodes
         del graph_dict["nodes"]
     graphs_in = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
     expected_num_nodes = np.sum(self.reference_graph.n_node)
     filled = utils_tf.set_zero_node_features(graphs_in, node_size)
     static_shape = filled.nodes.get_shape().as_list()
     self.assertAllEqual((expected_num_nodes, node_size), static_shape)
コード例 #8
0
 def test_output_values_larger_rank(self, broadcaster, expected):
     """Checks broadcasting when fields are reshaped to `[leading, 2, -1]`."""
     def _split_trailing(tensor):
         # Keep the leading dimension; fold the rest into [2, -1].
         leading = tensor.get_shape().as_list()[0]
         return tf.reshape(tensor, [leading] + [2, -1])

     graphs_in = utils_tf.data_dicts_to_graphs_tuple(
         [SMALL_GRAPH_1, SMALL_GRAPH_2]).map(_split_trailing)
     output_tensor = broadcaster(graphs_in)
     with self.test_session() as session:
         actual = session.run(output_tensor)
     expected_array = np.reshape(np.array(expected, dtype=np.float32),
                                 [len(expected)] + [2, -1])
     self.assertNDArrayNear(expected_array, actual, err=1e-4)
コード例 #9
0
 def setUp(self):
     """Builds a single random graph with 5 nodes and 10 edges."""
     super(RunGraphWithNoneInSessionTest, self).setUp()
     num_nodes = 5
     num_edges = 10
     # Sender/receiver indices must refer to existing nodes, so they are
     # sampled from [0, num_nodes).
     self._graph = utils_tf.data_dicts_to_graphs_tuple([{
         "senders":
         tf.random_uniform([num_edges], maxval=num_nodes, dtype=tf.int32),
         "receivers":
         tf.random_uniform([num_edges], maxval=num_nodes, dtype=tf.int32),
         "nodes":
         tf.random_uniform([num_nodes, 7]),
         "edges":
         tf.random_uniform([num_edges, 6]),
         "globals":
         tf.random_uniform([1, 8])
     }])
コード例 #10
0
 def setUp(self):
     super(StopGradientsGraphTest, self).setUp()
     # One graph with 5 nodes and 10 edges; senders and receivers are all
     # zeros, so every edge connects node 0 to itself.
     graph_dict = {
         "senders": tf.zeros([10], dtype=tf.int32),
         "receivers": tf.zeros([10], dtype=tf.int32),
         "nodes": tf.ones([5, 7]),
         "edges": tf.zeros([10, 6]),
         "globals": tf.zeros([1, 8]),
     }
     self._graph = utils_tf.data_dicts_to_graphs_tuple([graph_dict])
コード例 #11
0
 def test_fill_state_user_specified_types(self, dtype):
     """Checks that zero-filled features use the explicitly requested dtype."""
     for graph_dict in self.graphs_dicts_in:
         for field in ("nodes", "globals", "edges"):
             graph_dict.pop(field)
     filled = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
     filled = utils_tf.set_zero_edge_features(filled, 1, dtype)
     filled = utils_tf.set_zero_node_features(filled, 1, dtype)
     filled = utils_tf.set_zero_global_features(filled, 1, dtype)
     for field in ("edges", "nodes", "globals"):
         self.assertEqual(dtype, getattr(filled, field).dtype)
コード例 #12
0
 def test_fill_edge_state_dynamic(self, edge_size):
     """Checks zero-filled edges when `n_edge` comes from a placeholder."""
     for graph_dict in self.graphs_dicts_in:
         del graph_dict["edges"]
     graphs_in = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
     original_n_edge = graphs_in.n_edge
     fed_n_edge = tf.placeholder_with_default(
         original_n_edge, shape=original_n_edge.get_shape())
     graphs_in = graphs_in._replace(n_edge=fed_n_edge)
     total_edges = np.sum(self.reference_graph.n_edge)
     filled = utils_tf.set_zero_edge_features(graphs_in, edge_size)
     with self.test_session() as session:
         actual_edges = session.run(filled.edges)
     expected_edges = np.zeros((total_edges, edge_size))
     self.assertNDArrayNear(expected_edges, actual_edges, err=1e-4)
コード例 #13
0
    def test_getitem_one(self):
        """Checks that `get_graph` with an int index returns that one graph."""
        graph_index = 2
        expected_dict = self.graphs_dicts_out[graph_index]

        graphs_in = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
        single_graph = utils_tf.make_runnable_in_session(
            utils_tf.get_graph(graphs_in, graph_index))

        with self.test_session() as session:
            graph_np = session.run(single_graph)
        (actual_dict,) = utils_np.graphs_tuple_to_data_dicts(graph_np)

        for field, value in expected_dict.items():
            self.assertAllClose(value, actual_dict[field])
        self.assertEqual(expected_dict["nodes"].shape[0],
                         actual_dict["n_node"])
        self.assertEqual(expected_dict["edges"].shape[0],
                         actual_dict["n_edge"])
コード例 #14
0
 def test_data_dicts_to_graphs_tuple_cast_types(self):
     """Index/count fields become int32; feature fields keep float64."""
     for graph_dict in self.graphs_dicts_in:
         num_nodes = graph_dict["nodes"].shape[0]
         graph_dict["n_node"] = np.array(num_nodes, dtype=np.int64)
         graph_dict["receivers"] = graph_dict["receivers"].astype(np.int16)
         graph_dict["senders"] = graph_dict["senders"].astype(np.float64)
         graph_dict["nodes"] = graph_dict["nodes"].astype(np.float64)
         graph_dict["edges"] = tf.constant(graph_dict["edges"],
                                           dtype=tf.float64)
     out = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
     expected_dtypes = [
         ("n_node", tf.int32),
         ("n_edge", tf.int32),
         ("receivers", tf.int32),
         ("senders", tf.int32),
         ("nodes", tf.float64),
         ("edges", tf.float64),
     ]
     for field, wanted in expected_dtypes:
         actual = getattr(out, field).dtype
         self.assertEqual(wanted, actual)
         self.assertEqual(type(wanted), type(actual))
コード例 #15
0
 def test_fill_state_default_types(self):
     """Checks that zero-filled features default to `tf.float32`."""
     for graph_dict in self.graphs_dicts_in:
         for field in ("nodes", "globals", "edges"):
             del graph_dict[field]
     filled = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
     filled = utils_tf.set_zero_edge_features(filled, edge_size=1)
     filled = utils_tf.set_zero_node_features(filled, node_size=1)
     filled = utils_tf.set_zero_global_features(filled, global_size=1)
     for field in ("edges", "nodes", "globals"):
         self.assertEqual(tf.float32, getattr(filled, field).dtype)
コード例 #16
0
 def test_fully_connect_graph_static_with_dynamic_sizes_raises(self):
     """Checks the ValueErrors `fully_connect_graph_static` raises.

     Masking the leading dimension of `n_node` or `nodes` must raise a
     "known at construction time" error. The unmasked fixture must raise a
     "the same in all graphs" error.
     """
     for g in self.graphs_dicts_in:
         g.pop("edges")
         g.pop("receivers")
         g.pop("senders")
     graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(
         self.graphs_dicts_in)
     # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12).
     graphs_tuple_1 = graphs_tuple.map(test_utils.mask_leading_dimension,
                                       ["n_node"])
     with self.assertRaisesRegex(ValueError, "known at construction time"):
         utils_tf.fully_connect_graph_static(graphs_tuple_1)
     graphs_tuple_2 = graphs_tuple.map(test_utils.mask_leading_dimension,
                                       ["nodes"])
     with self.assertRaisesRegex(ValueError, "known at construction time"):
         utils_tf.fully_connect_graph_static(graphs_tuple_2)
     with self.assertRaisesRegex(ValueError, "the same in all graphs"):
         utils_tf.fully_connect_graph_static(graphs_tuple)
コード例 #17
0
 def test_fill_global_state_dynamic(self, global_size):
     """Checks zero-filled globals when `n_node` has an unknown length."""
     for graph_dict in self.graphs_dicts_in:
         del graph_dict["globals"]
     graphs_in = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
     # Hide the static number of graphs: `n_node` gets shape [None].
     hidden_n_node = tf.placeholder_with_default(graphs_in.n_node,
                                                 shape=[None])
     graphs_in = graphs_in._replace(n_node=hidden_n_node)
     num_graphs = self.reference_graph.n_edge.shape[0]
     filled = utils_tf.set_zero_global_features(graphs_in, global_size)
     with self.test_session() as session:
         actual_globals = session.run(filled.globals)
     expected_globals = np.zeros((num_graphs, global_size))
     self.assertNDArrayNear(expected_globals, actual_globals, err=1e-4)
コード例 #18
0
    def test_getitem(self):
        """Checks that `get_graph` with a slice returns the selected graphs."""
        index = slice(1, 3)
        expected = self.graphs_dicts_out[index]

        graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(
            self.graphs_dicts_in)
        graphs2_op = utils_tf.get_graph(graphs_tuple, index)
        graphs2_op = utils_tf.make_runnable_in_session(graphs2_op)

        with self.test_session() as sess:
            graphs2 = sess.run(graphs2_op)
        actual = utils_np.graphs_tuple_to_data_dicts(graphs2)

        # `zip` stops at the shorter sequence. Without this check, a wrong
        # number of returned graphs would go unnoticed.
        self.assertEqual(len(expected), len(actual))
        for ex, ac in zip(expected, actual):
            for k, v in ex.items():
                self.assertAllClose(v, ac[k])
            self.assertEqual(ex["nodes"].shape[0], ac["n_node"])
            self.assertEqual(ex["edges"].shape[0], ac["n_edge"])
コード例 #19
0
 def test_fully_connect_graph_static(self, exclude_self_edges):
     """Checks `fully_connect_graph_static` on graphs with equal node counts.

     Builds `num_graphs` graphs of `num_nodes` nodes each. It then checks the
     static shapes of `receivers`, `senders` and `n_edge`, the per-graph edge
     counts, and the exact set of (receiver, sender) pairs, with or without
     self-edges.
     """
     # The test builds its own graph dicts, so the shared fixture is not
     # modified here.
     num_graphs = 2
     num_nodes = 3
     if exclude_self_edges:
         n_edges = num_nodes * (num_nodes - 1)
     else:
         n_edges = num_nodes * num_nodes
     n_relation = num_graphs * n_edges
     graphs_dicts = [{
         "nodes": tf.zeros([num_nodes, 1])
     } for _ in range(num_graphs)]
     graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(graphs_dicts)
     graphs_tuple = utils_tf.fully_connect_graph_static(
         graphs_tuple, exclude_self_edges)
     self.assertAllEqual((n_relation, ),
                         graphs_tuple.receivers.get_shape().as_list())
     self.assertAllEqual((n_relation, ),
                         graphs_tuple.senders.get_shape().as_list())
     self.assertAllEqual((num_graphs, ),
                         graphs_tuple.n_edge.get_shape().as_list())
     with self.test_session() as sess:
         actual_receivers, actual_senders, actual_n_edge = sess.run([
             graphs_tuple.receivers, graphs_tuple.senders,
             graphs_tuple.n_edge
         ])
     expected_edges = []
     offset = 0
     for _ in range(num_graphs):
         for v1 in range(num_nodes):
             for v2 in range(num_nodes):
                 if not exclude_self_edges or v1 != v2:
                     expected_edges.append((v1 + offset, v2 + offset))
         # Node indices of later graphs are shifted by the preceding nodes.
         offset += num_nodes
     actual_edges = zip(actual_receivers, actual_senders)
     self.assertNDArrayNear(np.array([n_edges] * num_graphs), actual_n_edge,
                            1e-4)
     self.assertSetEqual(set(actual_edges), set(expected_edges))
コード例 #20
0
 def test_data_dicts_to_graphs_tuple(self, none_fields):
     """Checks that fields listed in `none_fields` come out as `None`."""
     for field in none_fields:
         for graph_dict in self.graphs_dicts_in:
             present = field in graph_dict
             if present and field == "nodes":
                 # Without node features the node count must be explicit.
                 graph_dict["n_node"] = graph_dict["nodes"].shape[0]
             if present:
                 graph_dict[field] = None
             self.reference_graph = self.reference_graph._replace(
                 **{field: None})
         if field == "senders":
             # Clearing senders is mirrored by zeroing `n_edge` in the
             # reference graph.
             zero_n_edge = np.zeros_like(self.reference_graph.n_edge)
             self.reference_graph = self.reference_graph._replace(
                 n_edge=zero_n_edge)
     graphs_out = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
     for field in none_fields:
         self.assertEqual(None, getattr(graphs_out, field))
     # Map the cleared fields through `tf.no_op` before running the tuple.
     runnable = graphs_out.map(tf.no_op, none_fields)
     with self.test_session() as session:
         self._assert_graph_equals_np(self.reference_graph,
                                      session.run(runnable))
コード例 #21
0
    def test_fully_connect_graph_dynamic_with_dynamic_sizes(
            self, exclude_self_edges):
        """Checks `fully_connect_graph_dynamic` with masked leading dims."""
        for graph_dict in self.graphs_dicts_in:
            for field in ("edges", "receivers", "senders"):
                graph_dict.pop(field)

        node_counts = [g["nodes"].shape[0] for g in self.graphs_dicts_in]
        n_relation = 0
        for count in node_counts:
            # Every ordered node pair, minus self-edges if they are excluded.
            partners = count - 1 if exclude_self_edges else count
            n_relation += count * partners

        graphs_in = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
        graphs_in = graphs_in.map(test_utils.mask_leading_dimension,
                                  ["nodes", "globals", "n_node", "n_edge"])
        connected = utils_tf.fully_connect_graph_dynamic(
            graphs_in, exclude_self_edges)
        with self.test_session() as session:
            receivers, senders, n_edge = session.run(
                [connected.receivers, connected.senders, connected.n_edge])
        self.assertAllEqual((n_relation, ), receivers.shape)
        self.assertAllEqual((n_relation, ), senders.shape)
        self.assertAllEqual((len(self.graphs_dicts_in), ), n_edge.shape)

        expected_edges = set()
        offset = 0
        for count in node_counts:
            for first in range(count):
                for second in range(count):
                    if exclude_self_edges and first == second:
                        continue
                    expected_edges.add((first + offset, second + offset))
            offset += count
        self.assertSetEqual(set(zip(receivers, senders)), expected_edges)
コード例 #22
0
    def test_fully_connect_graph_dynamic(self, exclude_self_edges):
        """Checks output sizes of `fully_connect_graph_dynamic`."""
        for graph_dict in self.graphs_dicts_in:
            for field in ("edges", "receivers", "senders"):
                graph_dict.pop(field)
        n_relation = 0
        for graph_dict in self.graphs_dicts_in:
            count = graph_dict["nodes"].shape[0]
            partners = count - 1 if exclude_self_edges else count
            n_relation += count * partners

        graphs_in = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
        connected = utils_tf.fully_connect_graph_dynamic(
            graphs_in, exclude_self_edges)
        with self.test_session() as session:
            receivers, senders = session.run(
                [connected.receivers, connected.senders])
        self.assertAllEqual((n_relation, ), receivers.shape)
        self.assertAllEqual((n_relation, ), senders.shape)
        static_n_edge_shape = connected.n_edge.get_shape().as_list()
        self.assertAllEqual((len(self.graphs_dicts_in), ), static_n_edge_shape)
コード例 #23
0
 def test_data_dicts_to_graphs_tuple_no_raise(self):
     """Missing node features are accepted when `n_node` is supplied."""
     for graph_dict in self.graphs_dicts_in:
         num_nodes = graph_dict["nodes"].shape[0]
         graph_dict.update(n_node=num_nodes, nodes=None)
     utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
コード例 #24
0
 def test_data_dicts_to_graphs_tuple_raises(self, none_field):
     """Fields that cannot be missing raise a ValueError naming the field."""
     for graph_dict in self.graphs_dicts_in:
         graph_dict[none_field] = None
     # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12).
     with self.assertRaisesRegex(ValueError, none_field):
         utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
コード例 #25
0
 def setUp(self):
     super(TestNumGraphs, self).setUp()
     # The fixture graphs with every field in `graphs.GRAPH_DATA_FIELDS` set
     # to None.
     full_graph = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
     self.empty_graph = full_graph.map(lambda _: None,
                                       graphs.GRAPH_DATA_FIELDS)