Esempio n. 1
0
async def db():
    """Build a ``DataManager`` populated with a single user, alice.

    Alice ends up with:

    - one stopped cluster whose two workers are both stopped, and
    - one active cluster with two active workers and one stopped worker.

    Counting only the active cluster and its active workers, alice holds:

    - 1 cluster
    - 5 cores (1 for the cluster itself + 2 per active worker)
    - 3 GiB memory (1 GiB for the cluster + 1 GiB per active worker)
    """
    db = objects.DataManager()
    db.load_database_state()

    alice = db.get_or_create_user("alice")

    # Create one stopped cluster: stop each of its workers, then stop the
    # cluster itself exactly once, after all of its workers exist. Stopping
    # the cluster inside the loop would mark it STOPPED before the second
    # worker is created and would overwrite its stop_time on every pass.
    c = db.create_cluster(alice, {}, memory=GiB, cores=1)
    for _ in range(2):
        w = db.create_worker(c, memory=GiB, cores=2)
        db.update_worker(w,
                         status=objects.WorkerStatus.STOPPED,
                         stop_time=objects.timestamp())
    db.update_cluster(c,
                      status=objects.ClusterStatus.STOPPED,
                      stop_time=objects.timestamp())

    # Create one active cluster with 2 active workers, 1 stopped worker
    c = db.create_cluster(alice, {}, memory=GiB, cores=1)
    for i in range(3):
        w = db.create_worker(c, memory=GiB, cores=2)
        if i == 0:
            db.update_worker(w,
                             status=objects.WorkerStatus.STOPPED,
                             stop_time=objects.timestamp())
    return db
Esempio n. 2
0
async def test_encryption(tmpdir):
    """Exercise DataManager encryption end to end.

    Checks that raw bytes round-trip through ``encrypt``/``decrypt``. It also
    checks that a cluster's TLS credentials and token are stored encrypted in
    the database, and that a fresh manager using the same keys reads them
    back unchanged.
    """
    sqlite_path = tmpdir.join("dask_gateway.sqlite")
    db_url = f"sqlite:///{sqlite_path}"
    keys = [Fernet.generate_key() for _ in range(3)]

    manager = objects.DataManager(url=db_url, encrypt_keys=keys)
    manager.load_database_state()
    assert manager.fernet is not None

    # Plain bytes must change under encryption and come back intact.
    plaintext = b"my secret data"
    ciphertext = manager.encrypt(plaintext)
    assert ciphertext != plaintext
    roundtripped = manager.decrypt(ciphertext)
    assert plaintext == roundtripped

    user = manager.get_or_create_user("alice")
    cluster = manager.create_cluster(user, {})
    assert cluster.tls_cert is not None
    assert cluster.tls_key is not None

    # Read the raw row: the stored credentials must not equal the plaintext
    # "cert;key" pair, yet must decrypt back to it (and likewise the token).
    clusters_table = objects.clusters
    with manager.db.begin() as conn:
        row = conn.execute(
            clusters_table.select(clusters_table.c.id == cluster.id)
        ).fetchone()
    plain_creds = b";".join((cluster.tls_cert, cluster.tls_key))
    assert row.tls_credentials != plain_creds
    stored_cert, stored_key = manager.decrypt(row.tls_credentials).split(b";")
    stored_token = manager.decrypt(row.token).decode()
    assert stored_cert == cluster.tls_cert
    assert stored_key == cluster.tls_key
    assert stored_token == cluster.token

    # A second manager with the same keys must recover identical secrets.
    reloaded = objects.DataManager(url=db_url, encrypt_keys=keys)
    reloaded.load_database_state()
    reloaded_cluster = reloaded.id_to_cluster[cluster.id]
    for attr in ("tls_cert", "tls_key", "token"):
        assert getattr(reloaded_cluster, attr) == getattr(cluster, attr)
Esempio n. 3
0
async def test_cleanup_expired_clusters(monkeypatch):
    """Check that ``DataManager.cleanup_expired`` removes only the expected clusters.

    A fake clock is installed that advances 0.5 s on every ``time.time()``
    call. Five clusters are created around a ``cutoff`` reading of that clock:

    - c1, c2: stopped, created before ``cutoff`` -> expected to be expired
    - c3: left running, created before ``cutoff`` -> kept
    - c4: stopped, created after ``cutoff`` -> kept
    - c5: left running, created after ``cutoff`` -> kept

    NOTE(review): this assumes ``objects.timestamp()`` (and anything inside
    ``DataManager`` that timestamps) reads ``time.time()``. Because the fake
    clock ticks on every call, the number and order of calls below
    determines every timestamp, so statement order here is significant.
    """
    db = objects.DataManager()
    db.load_database_state()

    alice = db.get_or_create_user("alice")

    # Fake clock state, seeded from the real clock before it is patched.
    current_time = time.time()

    def mytime():
        """Advance the fake clock by 0.5 s and return the new time."""
        # Every call moves time forward, so successive readings never tie.
        nonlocal current_time
        current_time += 0.5
        return current_time

    monkeypatch.setattr(time, "time", mytime)

    def add_cluster(stop=True):
        """Create a cluster for alice with 5 workers and return it.

        When ``stop`` is true, each worker is stopped right after it is
        created, and the cluster is stopped once all workers exist.
        """
        c = db.create_cluster(alice, {}, memory=1e9, cores=1)
        for _ in range(5):
            w = db.create_worker(c, memory=2e9, cores=2)
            if stop:
                db.update_worker(
                    w,
                    status=objects.WorkerStatus.STOPPED,
                    stop_time=objects.timestamp(),
                )
        if stop:
            db.update_cluster(c,
                              status=objects.ClusterStatus.STOPPED,
                              stop_time=objects.timestamp())
        return c

    add_cluster(stop=True)  # c1
    add_cluster(stop=True)  # c2
    c3 = add_cluster(stop=False)

    # Boundary between "old" and "new" clusters. Reading the clock also
    # advances it, so c4/c5 get timestamps strictly later than cutoff.
    cutoff = mytime()

    c4 = add_cluster(stop=True)
    c5 = add_cluster(stop=False)

    check_consistency(db)

    # Freeze the clock so both cleanup passes below see the same "now".
    now = mytime()
    monkeypatch.setattr(time, "time", lambda: now)

    # max_age reaches back exactly to cutoff. Presumably cleanup_expired
    # drops stopped clusters older than now - max_age; only c1 and c2
    # (stopped before cutoff) qualify.
    max_age = now - cutoff
    n = db.cleanup_expired(max_age)
    assert n == 2

    check_consistency(db)

    # c3, c4, c5 are all that remain
    assert set(db.id_to_cluster) == {c3.id, c4.id, c5.id}

    # A second pass with the same frozen clock and max_age finds nothing
    # left to expire.
    max_age = now - cutoff
    n = db.cleanup_expired(max_age)
    assert n == 0

    check_consistency(db)