def _get_input_graph(self, none_fields=None):
  """Builds one batched GraphsTuple from the four SMALL_GRAPH fixtures.

  Args:
    none_fields: optional collection of GraphsTuple field names whose values
      are replaced by `None` in the result. `None` means no field is cleared.

  Returns:
    The `GraphsTuple` produced by `utils_tf.data_dicts_to_graphs_tuple`, with
    every field named in `none_fields` set to `None`.
  """
  fields_to_clear = [] if none_fields is None else none_fields
  fixtures = [SMALL_GRAPH_1, SMALL_GRAPH_2, SMALL_GRAPH_3, SMALL_GRAPH_4]
  batched = utils_tf.data_dicts_to_graphs_tuple(fixtures)
  return batched.map(lambda _: None, fields_to_clear)
def test_output_values(self, broadcaster, expected):
  """Test the broadcasted output value.

  Runs `broadcaster` on a batch of SMALL_GRAPH_1 and SMALL_GRAPH_2 and
  compares the evaluated result against `expected` (as float32).
  """
  graphs_in = utils_tf.data_dicts_to_graphs_tuple(
      [SMALL_GRAPH_1, SMALL_GRAPH_2])
  result_op = broadcaster(graphs_in)
  with self.test_session() as sess:
    result = sess.run(result_op)
  expected_array = np.array(expected, dtype=np.float32)
  self.assertNDArrayNear(expected_array, result, err=1e-4)
def test_fill_edge_state(self, edge_size):
  """Tests for filling the edge state with a constant content.

  After removing `edges` from every input dict, `set_zero_edge_features`
  must produce an edge tensor whose static shape is
  `(total number of edges, edge_size)`.
  """
  for graph_dict in self.graphs_dicts_in:
    graph_dict.pop("edges")
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  total_edges = np.sum(self.reference_graph.n_edge)
  filled = utils_tf.set_zero_edge_features(graphs_tuple, edge_size)
  static_shape = filled.edges.get_shape().as_list()
  self.assertAllEqual((total_edges, edge_size), static_shape)
def create_data_ops(batch_size, num_elements_min_max):
  """Returns graphs containing the inputs and targets for classification.

  Refer to create_graph_dicts_tf and create_linked_list_target for more
  details.

  Args:
    batch_size: batch size for the `input_graphs`.
    num_elements_min_max: a 2-`tuple` of `int`s which define the [lower, upper)
      range of the number of elements per list.

  Returns:
    inputs_op: a `graphs.GraphsTuple` which contains the input list as a
      fully-connected graph.
    targets_op: a `graphs.GraphsTuple` which contains the target as a graph.
      Its node and edge features are the target values `t` concatenated with
      `1 - t` along axis 1.
    sort_indices_op: a `graphs.GraphsTuple` which contains the sort indices of
      the list elements as a fully-connected graph.
    ranks_op: a `graphs.GraphsTuple` which contains the ranks of the list
      elements as a fully-connected graph.
  """
  input_dicts, sort_indices_dicts, ranks_dicts = create_graph_dicts_tf(
      batch_size, num_elements_min_max)

  inputs_op = utils_tf.data_dicts_to_graphs_tuple(input_dicts)
  sort_indices_op = utils_tf.data_dicts_to_graphs_tuple(sort_indices_dicts)
  ranks_op = utils_tf.data_dicts_to_graphs_tuple(ranks_dicts)

  # Add edges to each graph by fully connecting its nodes.
  inputs_op = utils_tf.fully_connect_graph_dynamic(inputs_op)
  sort_indices_op = utils_tf.fully_connect_graph_dynamic(sort_indices_op)
  ranks_op = utils_tf.fully_connect_graph_dynamic(ranks_op)

  targets_op = create_linked_list_target(batch_size, sort_indices_op)
  # Two-column target encoding: [t, 1 - t] for both nodes and edges.
  nodes = tf.concat((targets_op.nodes, 1.0 - targets_op.nodes), axis=1)
  edges = tf.concat((targets_op.edges, 1.0 - targets_op.edges), axis=1)
  targets_op = targets_op._replace(nodes=nodes, edges=edges)

  return inputs_op, targets_op, sort_indices_op, ranks_op
def test_fill_global_state(self, global_size):
  """Tests for filling the global state with a constant content.

  After removing `globals` from every input dict, `set_zero_global_features`
  must produce a globals tensor with static shape `(n_graphs, global_size)`.
  """
  for graph_dict in self.graphs_dicts_in:
    del graph_dict["globals"]
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  # One `n_edge` entry per graph, so its length is the number of graphs.
  n_graphs = self.reference_graph.n_edge.shape[0]
  filled = utils_tf.set_zero_global_features(graphs_tuple, global_size)
  static_shape = filled.globals.get_shape().as_list()
  self.assertAllEqual((n_graphs, global_size), static_shape)
def test_fill_edge_state_with_missing_fields_raises(self):
  """Edge field cannot be filled if receivers or senders are missing.

  Removes `receivers`, `senders` and `edges` from every input dict and checks
  that `set_zero_edge_features` raises a `ValueError` mentioning
  "receivers".
  """
  for g in self.graphs_dicts_in:
    g.pop("receivers")
    g.pop("senders")
    g.pop("edges")
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  # The call is expected to raise, so its result is intentionally not bound.
  with self.assertRaisesRegexp(ValueError, "receivers"):
    utils_tf.set_zero_edge_features(graphs_tuple, edge_size=1)
def test_fill_node_state(self, node_size):
  """Tests for filling the node state with a constant content.

  Each input dict has its `nodes` replaced by an explicit `n_node` count.
  `set_zero_node_features` must then produce a node tensor with static shape
  `(total number of nodes, node_size)`.
  """
  for graph_dict in self.graphs_dicts_in:
    graph_dict["n_node"] = graph_dict.pop("nodes").shape[0]
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  total_nodes = np.sum(self.reference_graph.n_node)
  filled = utils_tf.set_zero_node_features(graphs_tuple, node_size)
  static_shape = filled.nodes.get_shape().as_list()
  self.assertAllEqual((total_nodes, node_size), static_shape)
def test_output_values_larger_rank(self, broadcaster, expected):
  """Test the broadcasted output value.

  Same as the rank-2 check, but every mapped feature is first reshaped to
  `[leading_dim, 2, -1]`. The expected values are reshaped the same way.
  """
  def _to_rank_three(tensor):
    leading_dim = tensor.get_shape().as_list()[0]
    return tf.reshape(tensor, [leading_dim, 2, -1])

  graphs_in = utils_tf.data_dicts_to_graphs_tuple(
      [SMALL_GRAPH_1, SMALL_GRAPH_2])
  graphs_in = graphs_in.map(_to_rank_three)
  result_op = broadcaster(graphs_in)
  with self.test_session() as sess:
    result = sess.run(result_op)
  expected_array = np.array(expected, dtype=np.float32).reshape(
      [len(expected), 2, -1])
  self.assertNDArrayNear(expected_array, result, err=1e-4)
def setUp(self):
  """Creates a single random graph with 5 nodes and 10 edges."""
  super(RunGraphWithNoneInSessionTest, self).setUp()
  num_nodes = 5
  num_edges = 10
  # Sender/receiver indices address rows of `nodes`, so they are drawn from
  # [0, num_nodes). Larger values would point at nodes that do not exist.
  self._graph = utils_tf.data_dicts_to_graphs_tuple([{
      "senders": tf.random_uniform(
          [num_edges], maxval=num_nodes, dtype=tf.int32),
      "receivers": tf.random_uniform(
          [num_edges], maxval=num_nodes, dtype=tf.int32),
      "nodes": tf.random_uniform([num_nodes, 7]),
      "edges": tf.random_uniform([num_edges, 6]),
      "globals": tf.random_uniform([1, 8])
  }])
def setUp(self):
  """Creates a single constant graph with 5 nodes and 10 edges."""
  super(StopGradientsGraphTest, self).setUp()
  graph_dict = {
      "senders": tf.zeros([10], dtype=tf.int32),
      "receivers": tf.zeros([10], dtype=tf.int32),
      "nodes": tf.ones([5, 7]),
      "edges": tf.zeros([10, 6]),
      "globals": tf.zeros([1, 8]),
  }
  self._graph = utils_tf.data_dicts_to_graphs_tuple([graph_dict])
def test_fill_state_user_specified_types(self, dtype):
  """Tests that the features are created with the user-specified `dtype`.

  Removes `nodes`, `globals` and `edges` from every input dict, refills each
  with the `set_zero_*_features` helpers using size 1 and `dtype`, and checks
  that all three resulting tensors have that dtype.
  """
  for g in self.graphs_dicts_in:
    g.pop("nodes")
    g.pop("globals")
    g.pop("edges")
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  graphs_tuple = utils_tf.set_zero_edge_features(graphs_tuple, 1, dtype)
  graphs_tuple = utils_tf.set_zero_node_features(graphs_tuple, 1, dtype)
  graphs_tuple = utils_tf.set_zero_global_features(graphs_tuple, 1, dtype)
  self.assertEqual(dtype, graphs_tuple.edges.dtype)
  self.assertEqual(dtype, graphs_tuple.nodes.dtype)
  self.assertEqual(dtype, graphs_tuple.globals.dtype)
def test_fill_edge_state_dynamic(self, edge_size):
  """Tests for filling the edge state with a constant content.

  Here `n_edge` is routed through `tf.placeholder_with_default`, and the
  evaluated edge values (not just the static shape) are checked to be zeros.
  """
  for graph_dict in self.graphs_dicts_in:
    graph_dict.pop("edges")
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  static_n_edge = graphs_tuple.n_edge
  dynamic_n_edge = tf.placeholder_with_default(
      static_n_edge, shape=static_n_edge.get_shape())
  graphs_tuple = graphs_tuple._replace(n_edge=dynamic_n_edge)
  total_edges = np.sum(self.reference_graph.n_edge)
  filled = utils_tf.set_zero_edge_features(graphs_tuple, edge_size)
  with self.test_session() as sess:
    edge_values = sess.run(filled.edges)
  self.assertNDArrayNear(
      np.zeros((total_edges, edge_size)), edge_values, err=1e-4)
def test_getitem_one(self):
  """`get_graph` with an integer index returns that single graph."""
  index = 2
  expected = self.graphs_dicts_out[index]
  batched = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  single_graph_op = utils_tf.make_runnable_in_session(
      utils_tf.get_graph(batched, index))
  with self.test_session() as sess:
    single_graph = sess.run(single_graph_op)
  (actual,) = utils_np.graphs_tuple_to_data_dicts(single_graph)
  for key, expected_value in expected.items():
    self.assertAllClose(expected_value, actual[key])
  self.assertEqual(expected["nodes"].shape[0], actual["n_node"])
  self.assertEqual(expected["edges"].shape[0], actual["n_edge"])
def test_data_dicts_to_graphs_tuple_cast_types(self):
  """Index and number fields should be cast to tensors of the right type.

  The inputs deliberately use mismatched dtypes (int64 counts, int16/float64
  indices, float64 features, one of them as a tf.constant). Afterwards the
  count and index fields must be `tf.int32` and the features `tf.float64`.
  """
  for graph_dict in self.graphs_dicts_in:
    node_count = graph_dict["nodes"].shape[0]
    graph_dict["n_node"] = np.array(node_count, dtype=np.int64)
    graph_dict["receivers"] = graph_dict["receivers"].astype(np.int16)
    graph_dict["senders"] = graph_dict["senders"].astype(np.float64)
    graph_dict["nodes"] = graph_dict["nodes"].astype(np.float64)
    graph_dict["edges"] = tf.constant(graph_dict["edges"], dtype=tf.float64)
  out = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  for key in ["n_node", "n_edge", "receivers", "senders"]:
    actual_dtype = getattr(out, key).dtype
    self.assertEqual(tf.int32, actual_dtype)
    self.assertEqual(type(tf.int32), type(actual_dtype))
  for key in ["nodes", "edges"]:
    actual_dtype = getattr(out, key).dtype
    self.assertEqual(type(tf.float64), type(actual_dtype))
    self.assertEqual(tf.float64, actual_dtype)
def test_fill_state_default_types(self):
  """Tests that the features are created with the correct default type.

  With no dtype argument, the `set_zero_*_features` helpers must create
  `tf.float32` edges, nodes and globals.
  """
  for graph_dict in self.graphs_dicts_in:
    for field in ("nodes", "globals", "edges"):
      graph_dict.pop(field)
  filled = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  filled = utils_tf.set_zero_edge_features(filled, edge_size=1)
  filled = utils_tf.set_zero_node_features(filled, node_size=1)
  filled = utils_tf.set_zero_global_features(filled, global_size=1)
  for field in ("edges", "nodes", "globals"):
    self.assertEqual(tf.float32, getattr(filled, field).dtype)
def test_fully_connect_graph_static_with_dynamic_sizes_raises(self):
  """`fully_connect_graph_static` rejects graphs it cannot size statically.

  Masking the leading dimension of `n_node` or `nodes` must raise. So must an
  unmasked batch whose graphs have differing node counts (per the expected
  error message).
  """
  for graph_dict in self.graphs_dicts_in:
    for field in ("edges", "receivers", "senders"):
      graph_dict.pop(field)
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  for masked_field in ("n_node", "nodes"):
    masked = graphs_tuple.map(
        test_utils.mask_leading_dimension, [masked_field])
    with self.assertRaisesRegexp(ValueError, "known at construction time"):
      utils_tf.fully_connect_graph_static(masked)
  with self.assertRaisesRegexp(ValueError, "the same in all graphs"):
    utils_tf.fully_connect_graph_static(graphs_tuple)
def test_fill_global_state_dynamic(self, global_size):
  """Tests for filling the global state with a constant content.

  `n_node` is routed through a placeholder with an unknown length, and the
  evaluated globals are checked to be zeros of shape `(n_graphs, global_size)`.
  """
  for graph_dict in self.graphs_dicts_in:
    graph_dict.pop("globals")
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  # Hide the static length of `n_node` behind a placeholder.
  dynamic_n_node = tf.placeholder_with_default(
      graphs_tuple.n_node, shape=[None])
  graphs_tuple = graphs_tuple._replace(n_node=dynamic_n_node)
  n_graphs = self.reference_graph.n_edge.shape[0]
  filled = utils_tf.set_zero_global_features(graphs_tuple, global_size)
  with self.test_session() as sess:
    global_values = sess.run(filled.globals)
  self.assertNDArrayNear(
      np.zeros((n_graphs, global_size)), global_values, err=1e-4)
def test_getitem(self):
  """`get_graph` with a slice returns exactly the matching graphs."""
  index = slice(1, 3)
  expected = self.graphs_dicts_out[index]
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  graphs2_op = utils_tf.get_graph(graphs_tuple, index)
  graphs2_op = utils_tf.make_runnable_in_session(graphs2_op)
  with self.test_session() as sess:
    graphs2 = sess.run(graphs2_op)
  actual = utils_np.graphs_tuple_to_data_dicts(graphs2)
  # `zip` below stops at the shorter input, so compare the graph counts first;
  # otherwise missing or extra graphs would go unnoticed.
  self.assertEqual(len(expected), len(actual))
  for ex, ac in zip(expected, actual):
    for k, v in ex.items():
      self.assertAllClose(v, ac[k])
    self.assertEqual(ex["nodes"].shape[0], ac["n_node"])
    self.assertEqual(ex["edges"].shape[0], ac["n_edge"])
def test_fully_connect_graph_static(self, exclude_self_edges):
  """Checks shapes and connectivity produced by `fully_connect_graph_static`.

  Builds `num_graphs` graphs with `num_nodes` nodes each (the static variant
  needs equal node counts) and compares the produced receiver/sender pairs
  against every ordered node pair. When `exclude_self_edges` is set,
  self-loops are left out.
  """
  # This test builds its own inputs, so the `self.graphs_dicts_in` fixtures
  # are not used (or modified) here.
  num_graphs = 2
  num_nodes = 3
  if exclude_self_edges:
    n_edges = num_nodes * (num_nodes - 1)
  else:
    n_edges = num_nodes * num_nodes
  n_relation = num_graphs * n_edges
  graphs_dicts = [{
      "nodes": tf.zeros([num_nodes, 1])
  } for _ in range(num_graphs)]
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(graphs_dicts)
  graphs_tuple = utils_tf.fully_connect_graph_static(
      graphs_tuple, exclude_self_edges)
  self.assertAllEqual((n_relation,),
                      graphs_tuple.receivers.get_shape().as_list())
  self.assertAllEqual((n_relation,),
                      graphs_tuple.senders.get_shape().as_list())
  self.assertAllEqual((num_graphs,),
                      graphs_tuple.n_edge.get_shape().as_list())
  with self.test_session() as sess:
    actual_receivers, actual_senders, actual_n_edge = sess.run([
        graphs_tuple.receivers, graphs_tuple.senders, graphs_tuple.n_edge
    ])
  expected_edges = []
  offset = 0
  for _ in range(num_graphs):
    for v1 in range(num_nodes):
      for v2 in range(num_nodes):
        if not exclude_self_edges or v1 != v2:
          expected_edges.append((v1 + offset, v2 + offset))
    # Node indices are global across the batch, so shift by graph.
    offset += num_nodes
  actual_edges = zip(actual_receivers, actual_senders)
  self.assertNDArrayNear(
      np.array([n_edges] * num_graphs), actual_n_edge, 1e-4)
  self.assertSetEqual(set(actual_edges), set(expected_edges))
def test_data_dicts_to_graphs_tuple(self, none_fields):
  """Fields in `none_fields` will be cleared out.

  For each field in `none_fields`, sets it to `None` in every input data dict
  that contains it and applies the same change to `self.reference_graph`.
  The test then checks that `utils_tf.data_dicts_to_graphs_tuple` yields
  `None` for those fields. Finally, it compares the evaluated graph against
  the updated reference.

  Args:
    none_fields: names of GraphsTuple fields to clear in the input dicts.
  """
  for field in none_fields:
    for graph_dict in self.graphs_dicts_in:
      if field in graph_dict:
        if field == "nodes":
          # Without `nodes`, the node count must be provided explicitly.
          graph_dict["n_node"] = graph_dict["nodes"].shape[0]
        graph_dict[field] = None
        self.reference_graph = self.reference_graph._replace(
            **{field: None})
      if field == "senders":
        # Clearing senders also zeroes the reference edge counts.
        self.reference_graph = self.reference_graph._replace(
            n_edge=np.zeros_like(self.reference_graph.n_edge))
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(
      self.graphs_dicts_in)
  for field in none_fields:
    self.assertEqual(None, getattr(graphs_tuple, field))
  # Map the cleared (None) fields through `tf.no_op` before `sess.run`.
  # NOTE(review): presumably this gives `sess.run` something fetchable in
  # place of `None`; confirm.
  graphs_tuple = graphs_tuple.map(tf.no_op, none_fields)
  with self.test_session() as sess:
    self._assert_graph_equals_np(self.reference_graph, sess.run(graphs_tuple))
def test_fully_connect_graph_dynamic_with_dynamic_sizes(
    self, exclude_self_edges):
  """`fully_connect_graph_dynamic` works when leading dimensions are masked.

  The leading dimension of `nodes`, `globals`, `n_node` and `n_edge` is
  hidden. The evaluated senders/receivers must then cover every ordered node
  pair within each graph, minus self-loops when `exclude_self_edges` is set.
  """
  for graph_dict in self.graphs_dicts_in:
    for field in ("edges", "receivers", "senders"):
      graph_dict.pop(field)

  def _num_edges(node_count):
    if exclude_self_edges:
      return node_count * (node_count - 1)
    return node_count * node_count

  n_relation = sum(
      _num_edges(g["nodes"].shape[0]) for g in self.graphs_dicts_in)
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  graphs_tuple = graphs_tuple.map(
      test_utils.mask_leading_dimension,
      ["nodes", "globals", "n_node", "n_edge"])
  graphs_tuple = utils_tf.fully_connect_graph_dynamic(
      graphs_tuple, exclude_self_edges)
  with self.test_session() as sess:
    receivers, senders, n_edge = sess.run([
        graphs_tuple.receivers, graphs_tuple.senders, graphs_tuple.n_edge
    ])
  self.assertAllEqual((n_relation,), receivers.shape)
  self.assertAllEqual((n_relation,), senders.shape)
  self.assertAllEqual((len(self.graphs_dicts_in),), n_edge.shape)

  expected_edges = set()
  offset = 0
  for graph in self.graphs_dicts_in:
    node_count = graph["nodes"].shape[0]
    expected_edges.update(
        (offset + i, offset + j)
        for i in range(node_count)
        for j in range(node_count)
        if not exclude_self_edges or i != j)
    offset += node_count
  self.assertSetEqual(set(zip(receivers, senders)), expected_edges)
def test_fully_connect_graph_dynamic(self, exclude_self_edges):
  """Checks edge counts produced by `fully_connect_graph_dynamic`.

  The evaluated receivers/senders must have one entry per ordered node pair
  in each graph (excluding self-loops when requested). `n_edge` must have
  one static entry per graph.
  """
  for graph_dict in self.graphs_dicts_in:
    for field in ("edges", "receivers", "senders"):
      graph_dict.pop(field)
  node_counts = [g["nodes"].shape[0] for g in self.graphs_dicts_in]
  if exclude_self_edges:
    n_relation = sum(n * (n - 1) for n in node_counts)
  else:
    n_relation = sum(n * n for n in node_counts)
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  graphs_tuple = utils_tf.fully_connect_graph_dynamic(
      graphs_tuple, exclude_self_edges)
  with self.test_session() as sess:
    receivers, senders = sess.run(
        [graphs_tuple.receivers, graphs_tuple.senders])
  self.assertAllEqual((n_relation,), receivers.shape)
  self.assertAllEqual((n_relation,), senders.shape)
  self.assertAllEqual((len(self.graphs_dicts_in),),
                      graphs_tuple.n_edge.get_shape().as_list())
def test_data_dicts_to_graphs_tuple_no_raise(self):
  """Not having nodes is fine, if the number of nodes is provided."""
  for graph_dict in self.graphs_dicts_in:
    node_count = graph_dict["nodes"].shape[0]
    graph_dict.update(n_node=node_count, nodes=None)
  # Passing means no exception is raised.
  utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
def test_data_dicts_to_graphs_tuple_raises(self, none_field):
  """Fields that cannot be missing.

  Setting `none_field` to `None` in every input dict must make
  `data_dicts_to_graphs_tuple` raise a `ValueError` naming that field.
  """
  graph_dicts = self.graphs_dicts_in
  for graph_dict in graph_dicts:
    graph_dict.update({none_field: None})
  with self.assertRaisesRegexp(ValueError, none_field):
    utils_tf.data_dicts_to_graphs_tuple(graph_dicts)
def setUp(self):
  """Builds `self.empty_graph`: the fixtures with every data field cleared."""
  super(TestNumGraphs, self).setUp()
  full_graph = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  self.empty_graph = full_graph.map(
      lambda _: None, graphs.GRAPH_DATA_FIELDS)