async def db():
    """Initializes user alice with the following active resources:

    - 1 cluster
    - 5 cores
    - 3 GiB memory
    """
    dm = objects.DataManager()
    dm.load_database_state()
    alice = dm.get_or_create_user("alice")

    # First cluster: fully stopped (2 stopped workers, stopped cluster).
    # Contributes nothing to alice's active resource totals.
    stopped_cluster = dm.create_cluster(alice, {}, memory=GiB, cores=1)
    for _ in range(2):
        worker = dm.create_worker(stopped_cluster, memory=GiB, cores=2)
        dm.update_worker(
            worker,
            status=objects.WorkerStatus.STOPPED,
            stop_time=objects.timestamp(),
        )
    dm.update_cluster(
        stopped_cluster,
        status=objects.ClusterStatus.STOPPED,
        stop_time=objects.timestamp(),
    )

    # Second cluster: active, with 2 active workers and 1 stopped worker.
    # Active totals: 1 cluster, 1 + 2*2 = 5 cores, 3 * 1 GiB memory.
    active_cluster = dm.create_cluster(alice, {}, memory=GiB, cores=1)
    for i in range(3):
        worker = dm.create_worker(active_cluster, memory=GiB, cores=2)
        if i == 0:
            # Stop only the first worker; the other two stay active
            dm.update_worker(
                worker,
                status=objects.WorkerStatus.STOPPED,
                stop_time=objects.timestamp(),
            )
    return dm
async def test_encryption(tmpdir):
    """Secrets (TLS credentials, tokens) are stored encrypted at rest and
    can be decrypted both in-process and on a fresh reload with the keys."""
    db_url = "sqlite:///%s" % tmpdir.join("dask_gateway.sqlite")
    encrypt_keys = [Fernet.generate_key() for i in range(3)]
    db = objects.DataManager(url=db_url, encrypt_keys=encrypt_keys)
    db.load_database_state()

    # With keys configured, a fernet instance must exist and round-trip data
    assert db.fernet is not None
    plaintext = b"my secret data"
    ciphertext = db.encrypt(plaintext)
    assert ciphertext != plaintext
    assert db.decrypt(ciphertext) == plaintext

    alice = db.get_or_create_user("alice")
    c = db.create_cluster(alice, {})
    assert c.tls_cert is not None
    assert c.tls_key is not None

    # The raw rows in the database must not contain the plaintext secrets
    with db.db.begin() as conn:
        row = conn.execute(
            objects.clusters.select(objects.clusters.c.id == c.id)
        ).fetchone()
    assert row.tls_credentials != b";".join((c.tls_cert, c.tls_key))
    stored_cert, stored_key = db.decrypt(row.tls_credentials).split(b";")
    stored_token = db.decrypt(row.token).decode()
    assert stored_cert == c.tls_cert
    assert stored_key == c.tls_key
    assert stored_token == c.token

    # A second manager with the same keys can reload and decrypt everything
    db2 = objects.DataManager(url=db_url, encrypt_keys=encrypt_keys)
    db2.load_database_state()
    reloaded = db2.id_to_cluster[c.id]
    assert reloaded.tls_cert == c.tls_cert
    assert reloaded.tls_key == c.tls_key
    assert reloaded.token == c.token
async def test_cleanup_expired_clusters(monkeypatch):
    """Clusters stopped before the max-age cutoff are purged; running and
    recently-stopped clusters survive, and cleanup is idempotent."""
    db = objects.DataManager()
    db.load_database_state()
    alice = db.get_or_create_user("alice")

    # Fake clock: advances by 0.5s on every read so each timestamp is distinct
    current_time = time.time()

    def fake_time():
        nonlocal current_time
        current_time += 0.5
        return current_time

    monkeypatch.setattr(time, "time", fake_time)

    def make_cluster(stop=True):
        cluster = db.create_cluster(alice, {}, memory=1e9, cores=1)
        for _ in range(5):
            worker = db.create_worker(cluster, memory=2e9, cores=2)
            if stop:
                db.update_worker(
                    worker,
                    status=objects.WorkerStatus.STOPPED,
                    stop_time=objects.timestamp(),
                )
        if stop:
            db.update_cluster(
                cluster,
                status=objects.ClusterStatus.STOPPED,
                stop_time=objects.timestamp(),
            )
        return cluster

    # Two stopped clusters before the cutoff, then one running, one stopped,
    # and one running after it
    make_cluster(stop=True)  # c1
    make_cluster(stop=True)  # c2
    c3 = make_cluster(stop=False)
    cutoff = fake_time()
    c4 = make_cluster(stop=True)
    c5 = make_cluster(stop=False)
    check_consistency(db)

    # Freeze the clock so max_age computations are stable from here on
    now = fake_time()
    monkeypatch.setattr(time, "time", lambda: now)

    # Only the two clusters stopped before the cutoff should expire
    max_age = now - cutoff
    assert db.cleanup_expired(max_age) == 2
    check_consistency(db)

    # c3 (running), c4 (stopped after cutoff), c5 (running) all remain
    assert set(db.id_to_cluster) == {c3.id, c4.id, c5.id}

    # A second pass with the same max_age expires nothing
    max_age = now - cutoff
    assert db.cleanup_expired(max_age) == 0
    check_consistency(db)