示例#1
0
 def test_feed_data(self):
     """Feeds a 2-graph batch into placeholders built from a 16-graph batch.

     The placeholders are created with `force_dynamic_num_graphs=True`, so the
     fed batch may be smaller than the one used to build them. The fetched
     node, global and edge values (and sender/receiver indices) must match the
     fed graphs.
     """
     # Graphs used only to build the placeholders.
     template_graphs = [
         _generate_graph(batch_index) for batch_index in range(16)
     ]
     placeholders = utils_tf.placeholders_from_networkxs(
         template_graphs, force_dynamic_num_graphs=True)
     # Does not need to be the same size as the template batch.
     networkxs = [_generate_graph(batch_index) for batch_index in range(2)]
     with self.test_session() as sess:
         output = sess.run(
             placeholders,
             utils_tf.get_feed_dict(
                 placeholders,
                 utils_np.networkxs_to_graphs_tuple(networkxs)))
     self.assertAllEqual(
         np.array([[0, 0], [1, 0], [2, 0], [3, 0], [0, 1], [1, 1], [2, 1],
                   [3, 1]]), output.nodes)
     self.assertEqual(np.float32, output.nodes.dtype)
     self.assertAllEqual(np.array([[0], [1]]), output.globals)
     self.assertEqual(np.float32, output.globals.dtype)
     # Edges are compared independently of their output order, so sort them
     # by (receiver, sender). The explicit key stops `sorted` from falling
     # through to comparing the numpy feature arrays when two edges share
     # the same (receiver, sender) pair, which would raise ValueError.
     sorted_edges_content = sorted(
         zip(output.receivers, output.senders, output.edges),
         key=lambda edge: (edge[0], edge[1]))
     self.assertAllEqual([0, 0, 1, 4, 4, 5],
                         [edge[0] for edge in sorted_edges_content])
     self.assertAllEqual([1, 2, 3, 5, 6, 7],
                         [edge[1] for edge in sorted_edges_content])
     self.assertEqual(np.float64, output.edges.dtype)
     self.assertAllEqual(
         np.array([[0, 1, 0], [1, 2, 0], [2, 3, 0], [0, 1, 1], [1, 2, 1],
                   [2, 3, 1]]), [edge[2] for edge in sorted_edges_content])
示例#2
0
 def test_get_feed_dict_raises(self, none_fields):
     """`utils_tf.get_feed_dict` raises ValueError on mismatched None fields.

     Checks both directions:
     - the placeholders have the `none_fields` set to None while the fed
       values do not;
     - the fed values have them set to None while the placeholders do not.

     Args:
       none_fields: names of the graph fields to null out. Presumably supplied
         by a parameterization decorator outside this view — TODO confirm.
     """
     networkxs = [_generate_graph(batch_index) for batch_index in range(16)]
     placeholders = utils_tf.placeholders_from_networkxs(networkxs)
     feed_values = utils_np.networkxs_to_graphs_tuple(networkxs)
     # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12).
     # The empty pattern matched any message, so plain `assertRaises` is the
     # equivalent check.
     with self.assertRaises(ValueError):
         utils_tf.get_feed_dict(
             placeholders.map(lambda _: None, none_fields), feed_values)
     with self.assertRaises(ValueError):
         utils_tf.get_feed_dict(
             placeholders, feed_values.map(lambda _: None, none_fields))
示例#3
0
 def test_placeholders_from_networkxs(self):
     """Placeholders built with a static graph count have the expected form.

     Shapes are validated by `_assert_expected_shapes` for 16 graphs. Node
     placeholders are expected to be float32 and edge placeholders float64.
     """
     graph_count = 16
     graphs = [_generate_graph(index) for index in range(graph_count)]
     result = utils_tf.placeholders_from_networkxs(
         graphs, force_dynamic_num_graphs=False)
     self._assert_expected_shapes(result, num_graphs=graph_count)
     self.assertEqual(tf.float32, result.nodes.dtype)
     self.assertEqual(tf.float64, result.edges.dtype)
示例#4
0
 def test_placeholders_from_networkxs_missing_edges(self):
     """Edge-less graphs produce placeholders whose `edges` field is None.

     Every field except `edges` is still checked by `_assert_expected_shapes`
     for a static count of 16 graphs.
     """
     num_graphs = 16
     networkxs = [
         _generate_graph(batch_index, add_edges=False)
         for batch_index in range(num_graphs)
     ]
     placeholders = utils_tf.placeholders_from_networkxs(
         networkxs, force_dynamic_num_graphs=False)
     # Identity check: `assertEqual(None, x)` goes through `==`, which
     # tensor-like objects may overload.
     self.assertIsNone(placeholders.edges)
     self._assert_expected_shapes(placeholders,
                                  but_for=["edges"],
                                  num_graphs=num_graphs)
示例#5
0
 def test_placeholders_from_networkxs_hints(self):
     """Shape and dtype hints apply when the graphs carry no node/edge data.

     The graphs are generated with `n_nodes=0` and `add_edges=False`. The
     node/edge feature sizes are therefore expected to come from
     `node_shape_hint` and `edge_shape_hint`, and the dtype from
     `data_type_hint`.
     """
     num_graphs = 16
     empty_graphs = [
         _generate_graph(index, n_nodes=0, add_edges=False)
         for index in range(num_graphs)
     ]
     hinted = utils_tf.placeholders_from_networkxs(
         empty_graphs,
         node_shape_hint=[14],
         edge_shape_hint=[17],
         data_type_hint=tf.float64,
         force_dynamic_num_graphs=False)
     # The leading dimension stays unknown; the trailing one is the hint.
     for field_name, feature_size in (("nodes", 14), ("edges", 17)):
         self.assertAllEqual(
             [None, feature_size],
             getattr(hinted, field_name).shape.as_list())
     self._assert_expected_shapes(hinted,
                                  but_for=["nodes", "edges"],
                                  num_graphs=num_graphs)
     self.assertEqual(tf.float64, hinted.nodes.dtype)
     self.assertEqual(tf.float64, hinted.edges.dtype)
示例#6
0
 def test_feed_data_no_nodes(self):
     """Feeds node-less, edge-less graphs; only the globals carry data.

     Placeholders are built from 16 such graphs with a dynamic number of
     graphs and then fed with 2 graphs. The `nodes` and `edges` placeholders
     are None. Before running, they are swapped for `tf.no_op()`, presumably
     because `sess.run` cannot fetch None.
     """
     # Graphs used only to build the placeholders.
     template_graphs = [
         _generate_graph(batch_index, n_nodes=0, add_edges=False)
         for batch_index in range(16)
     ]
     placeholders = utils_tf.placeholders_from_networkxs(
         template_graphs, force_dynamic_num_graphs=True)
     # Does not need to be the same size as the template batch.
     networkxs = [
         _generate_graph(batch_index, n_nodes=0, add_edges=False)
         for batch_index in range(2)
     ]
     # Identity checks: `assertEqual(None, x)` goes through `==`, which
     # tensor-like objects may overload.
     self.assertIsNone(placeholders.nodes)
     self.assertIsNone(placeholders.edges)
     with self.test_session() as sess:
         output = sess.run(
             placeholders.replace(nodes=tf.no_op(), edges=tf.no_op()),
             utils_tf.get_feed_dict(
                 placeholders,
                 utils_np.networkxs_to_graphs_tuple(networkxs)))
     self.assertAllEqual(np.array([[0], [1]]), output.globals)
     self.assertEqual(np.float32, output.globals.dtype)