def test_feed_data(self):
    """Feeds a smaller batch into placeholders built with a dynamic graph count.

    Placeholders are built from 16 generated graphs with
    `force_dynamic_num_graphs=True`, then fed a GraphsTuple converted from only
    two generated graphs. The fetched nodes, globals and edges are checked for
    both values and dtypes.
    """
    template_graphs = [_generate_graph(batch_index) for batch_index in range(16)]
    placeholders = utils_tf.placeholders_from_networkxs(
        template_graphs, force_dynamic_num_graphs=True)
    # The fed batch does not need to be the same size as the template batch.
    networkxs = [_generate_graph(batch_index) for batch_index in range(2)]
    with self.test_session() as sess:
        output = sess.run(
            placeholders,
            utils_tf.get_feed_dict(
                placeholders, utils_np.networkxs_to_graphs_tuple(networkxs)))
    self.assertAllEqual(
        np.array([[0, 0], [1, 0], [2, 0], [3, 0], [0, 1], [1, 1], [2, 1],
                  [3, 1]]), output.nodes)
    self.assertEqual(np.float32, output.nodes.dtype)
    self.assertAllEqual(np.array([[0], [1]]), output.globals)
    self.assertEqual(np.float32, output.globals.dtype)
    # Sort edges so the assertions below do not depend on output edge order.
    # Sorting on an explicit (receiver, sender) key means a tie never falls
    # through to comparing the numpy edge-feature arrays, which would raise
    # an ambiguous-truth-value error.
    sorted_edges_content = sorted(
        zip(output.receivers, output.senders, output.edges),
        key=lambda edge: (edge[0], edge[1]))
    self.assertAllEqual([0, 0, 1, 4, 4, 5],
                        [x[0] for x in sorted_edges_content])
    self.assertAllEqual([1, 2, 3, 5, 6, 7],
                        [x[1] for x in sorted_edges_content])
    self.assertEqual(np.float64, output.edges.dtype)
    self.assertAllEqual(
        np.array([[0, 1, 0], [1, 2, 0], [2, 3, 0], [0, 1, 1], [1, 2, 1],
                  [2, 3, 1]]),
        [x[2] for x in sorted_edges_content])
def test_get_feed_dict_raises(self, none_fields):
    """Checks that get_feed_dict raises ValueError when fields disagree on None.

    The fields named in `none_fields` are set to None first on the
    placeholders only, then on the fed values only. Each case must raise
    ValueError.

    Args:
      none_fields: Names of GraphsTuple fields to replace with None (presumably
        supplied by a parameterized-test decorator outside this view; confirm).
    """
    networkxs = [_generate_graph(batch_index) for batch_index in range(16)]
    placeholders = utils_tf.placeholders_from_networkxs(networkxs)
    feed_values = utils_np.networkxs_to_graphs_tuple(networkxs)
    # No message pattern is checked, so plain assertRaises suffices; the
    # assertRaisesRegexp alias is deprecated and was removed in Python 3.12.
    with self.assertRaises(ValueError):
        utils_tf.get_feed_dict(
            placeholders.map(lambda _: None, none_fields), feed_values)
    with self.assertRaises(ValueError):
        utils_tf.get_feed_dict(
            placeholders, feed_values.map(lambda _: None, none_fields))
def test_placeholders_from_networkxs(self):
    """Builds placeholders with a static graph count and checks shapes/dtypes.

    Sixteen generated graphs are converted with
    `force_dynamic_num_graphs=False`. The resulting placeholders must have the
    expected shapes for that batch size, float32 nodes and float64 edges.
    """
    graph_count = 16
    graphs = [_generate_graph(index) for index in range(graph_count)]
    placeholders = utils_tf.placeholders_from_networkxs(
        graphs, force_dynamic_num_graphs=False)
    self._assert_expected_shapes(placeholders, num_graphs=graph_count)
    self.assertEqual(tf.float32, placeholders.nodes.dtype)
    self.assertEqual(tf.float64, placeholders.edges.dtype)
def test_placeholders_from_networkxs_missing_edges(self):
    """Graphs generated without edges give a None `edges` placeholder.

    All other placeholder fields are still checked against the expected shapes
    for a static batch of 16 graphs.
    """
    num_graphs = 16
    networkxs = [
        _generate_graph(batch_index, add_edges=False)
        for batch_index in range(num_graphs)
    ]
    placeholders = utils_tf.placeholders_from_networkxs(
        networkxs, force_dynamic_num_graphs=False)
    # Identity check against None; gives a clearer failure message than
    # assertEqual(None, ...).
    self.assertIsNone(placeholders.edges)
    self._assert_expected_shapes(
        placeholders, but_for=["edges"], num_graphs=num_graphs)
def test_placeholders_from_networkxs_hints(self):
    """Shape and dtype hints define node/edge placeholders when data is absent.

    The graphs are generated with `n_nodes=0, add_edges=False`. The node and
    edge placeholders must therefore take their trailing dimension and dtype
    from the hints. The remaining fields are checked against the expected
    shapes for the batch.
    """
    batch_size = 16
    empty_graphs = [
        _generate_graph(i, n_nodes=0, add_edges=False)
        for i in range(batch_size)
    ]
    placeholders = utils_tf.placeholders_from_networkxs(
        empty_graphs,
        node_shape_hint=[14],
        edge_shape_hint=[17],
        data_type_hint=tf.float64,
        force_dynamic_num_graphs=False)
    # Hinted fields: unknown leading dimension, hinted trailing dimension.
    for field_name, feature_size in (("nodes", 14), ("edges", 17)):
        self.assertAllEqual(
            [None, feature_size],
            getattr(placeholders, field_name).shape.as_list())
    self._assert_expected_shapes(
        placeholders, but_for=["nodes", "edges"], num_graphs=batch_size)
    for field_name in ("nodes", "edges"):
        self.assertEqual(tf.float64, getattr(placeholders, field_name).dtype)
def test_feed_data_no_nodes(self):
    """Feeds graphs that have no nodes or edges; only globals are fetched.

    Placeholders are built from 16 graphs generated with
    `n_nodes=0, add_edges=False` and a dynamic graph count. Both `nodes` and
    `edges` placeholders are expected to be None. A two-graph batch is then fed
    and the resulting globals are checked.
    """
    template_graphs = [
        _generate_graph(batch_index, n_nodes=0, add_edges=False)
        for batch_index in range(16)
    ]
    placeholders = utils_tf.placeholders_from_networkxs(
        template_graphs, force_dynamic_num_graphs=True)
    # The fed batch does not need to be the same size as the template batch.
    networkxs = [
        _generate_graph(batch_index, n_nodes=0, add_edges=False)
        for batch_index in range(2)
    ]
    self.assertIsNone(placeholders.nodes)
    self.assertIsNone(placeholders.edges)
    with self.test_session() as sess:
        # The None fields are swapped for no-op fetches, presumably because
        # sess.run does not accept None entries in its fetches.
        output = sess.run(
            placeholders.replace(nodes=tf.no_op(), edges=tf.no_op()),
            utils_tf.get_feed_dict(
                placeholders, utils_np.networkxs_to_graphs_tuple(networkxs)))
    self.assertAllEqual(np.array([[0], [1]]), output.globals)
    self.assertEqual(np.float32, output.globals.dtype)