Beispiel #1
0
 def test_serialize(self):
     """Dumping a two-pool PoolCollection reports each pool's slot stats plus the total count."""
     pool_names = ("test_pool_a", "test_pool_b")
     pools = [Pool(pool=pool_name, slots=3) for pool_name in pool_names]
     collection = PoolCollection(pools=pools, total_entries=2)
     # Freshly built pools have no task instances, so every slot is open.
     expected_pools = [
         {
             "name": pool_name,
             "slots": 3,
             "occupied_slots": 0,
             "running_slots": 0,
             "queued_slots": 0,
             "open_slots": 3,
         }
         for pool_name in pool_names
     ]
     assert {"pools": expected_pools, "total_entries": 2} == pool_collection_schema.dump(collection)
Beispiel #2
0
    def test_default_pool_open_slots(self):
        """A running and a queued (pool_slots=2) TI reduce the 5-slot default pool to 2 open slots."""
        set_default_pool_slots(5)
        assert Pool.get_default_pool().open_slots() == 5

        dag = DAG(dag_id='test_default_pool_open_slots', start_date=DEFAULT_DATE)
        single_slot_op = DummyOperator(task_id='dummy1', dag=dag)
        double_slot_op = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2)
        running_ti = TI(task=single_slot_op, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=double_slot_op, execution_date=DEFAULT_DATE)
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED

        session = settings.Session
        for ti in (running_ti, queued_ti):
            session.add(ti)
        session.commit()
        session.close()

        # 5 total - 1 running - 2 queued = 2 open
        assert Pool.get_default_pool().open_slots() == 2
        expected_stats = {
            "default_pool": {
                "open": 2,
                "queued": 2,
                "total": 5,
                "running": 1,
            }
        }
        assert Pool.slots_stats() == expected_stats
Beispiel #3
0
def add_default_pool_if_not_exists(session=None):
    """Create the default pool unless a pool with that name already exists.

    The slot count is read from ``[core] non_pooled_task_slot_count`` (fallback 128)
    and the new pool is committed immediately.

    :param session: ORM session used for the lookup, insert and commit. The body calls
        ``session.add``/``session.commit``, so a real session must be supplied
        (NOTE(review): presumably injected by a decorator not visible here — confirm).
    """
    from airflow.models.pool import Pool

    if Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session):
        return
    slot_count = conf.getint(section='core', key='non_pooled_task_slot_count', fallback=128)
    session.add(Pool(pool=Pool.DEFAULT_POOL_NAME, slots=slot_count, description="Default pool"))
    session.commit()
Beispiel #4
0
def add_default_pool_if_not_exists(session=None):
    """
    Add default pool if it does not exist.

    The slot count comes from ``[core] non_pooled_task_slot_count`` (fallback 128);
    the pool is committed right away. ``session`` must be a usable ORM session because
    ``add``/``commit`` are called on it (NOTE(review): likely supplied by a decorator
    outside this view — confirm).
    """
    already_present = Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session)
    if already_present:
        return
    slot_count = conf.getint(section='core', key='non_pooled_task_slot_count', fallback=128)
    session.add(Pool(pool=Pool.DEFAULT_POOL_NAME, slots=slot_count, description="Default pool"))
    session.commit()
Beispiel #5
0
 def test_response_200(self, session):
     """GET /pools lists the default pool followed by the pool created in this test."""
     session.add(Pool(pool="test_pool_a", slots=3))
     session.commit()
     # The default pool already exists, hence two rows.
     assert len(session.query(Pool).all()) == 2
     response = self.client.get("/api/v1/pools", environ_overrides={'REMOTE_USER': "******"})
     assert response.status_code == 200
     expected_pools = [
         {
             "name": "default_pool",
             "slots": 128,
             "occupied_slots": 0,
             "running_slots": 0,
             "queued_slots": 0,
             "open_slots": 128,
         },
         {
             "name": "test_pool_a",
             "slots": 3,
             "occupied_slots": 0,
             "running_slots": 0,
             "queued_slots": 0,
             "open_slots": 3,
         },
     ]
     self.assertEqual({"pools": expected_pools, "total_entries": 2}, response.json)
Beispiel #6
0
    def _validate_task_pools(self,
                             *,
                             dagbag: DagBag,
                             session: Session = NEW_SESSION):
        """
        Record a warning for every DAG or sub-DAG whose tasks reference a non-existent pool.

        Nothing is raised: each offending DAG contributes one ``DagWarning`` of type
        ``DagWarningType.NONEXISTENT_POOL`` to ``self.dag_warnings``.

        :param dagbag: DagBag whose DAGs (and their ``subdags``) are checked.
        :param session: ORM session used to load the existing pools.
        :meta private:
        """
        from airflow.models.pool import Pool

        # Build the set of known pool names once, before any DAG is checked.
        existing_pools = {p.pool for p in Pool.get_pools(session)}

        def check_pools(dag):
            """Return a warning message if ``dag`` uses unknown pools, otherwise ``None``."""
            nonexistent_pools = {task.pool for task in dag.tasks} - existing_pools
            if nonexistent_pools:
                return (
                    f"Dag '{dag.dag_id}' references non-existent pools: {sorted(nonexistent_pools)!r}"
                )
            return None

        for dag in dagbag.dags.values():
            # Check the DAG itself first, then each of its sub-DAGs.
            for checked_dag in (dag, *dag.subdags):
                message = check_pools(checked_dag)
                if message:
                    self.dag_warnings.add(
                        DagWarning(checked_dag.dag_id, DagWarningType.NONEXISTENT_POOL, message)
                    )
Beispiel #7
0
 def test_should_respect_page_size_limit_default(self, session):
     """Without an explicit limit only the default page of 100 pools is returned."""
     session.add_all([Pool(pool=f"test_pool{idx}", slots=1) for idx in range(1, 121)])
     session.commit()
     # 120 created here plus the default pool.
     self.assertEqual(session.query(Pool).count(), 121)
     response = self.client.get("/api/v1/pools", environ_overrides={'REMOTE_USER': "******"})
     assert response.status_code == 200
     self.assertEqual(len(response.json['pools']), 100)
Beispiel #8
0
 def test_should_return_conf_max_if_req_max_above_conf(self, session):
     """A requested limit above the configured maximum is capped at 150 pools."""
     new_pools = [Pool(pool=f"test_pool{idx}", slots=1) for idx in range(1, 200)]
     session.add_all(new_pools)
     session.commit()
     # 199 created here plus the default pool.
     self.assertEqual(session.query(Pool).count(), 200)
     response = self.client.get("/api/v1/pools?limit=180", environ_overrides={'REMOTE_USER': "******"})
     assert response.status_code == 200
     self.assertEqual(len(response.json['pools']), 150)
Beispiel #9
0
    def test_should_raises_401_unauthenticated(self, session):
        """PATCHing a pool without credentials is rejected with 401."""
        session.add(Pool(pool="test_pool", slots=2))
        session.commit()

        payload = {"name": "test_pool_a", "slots": 3}
        response = self.client.patch("api/v1/pools/test_pool", json=payload)

        assert_401(response)
Beispiel #10
0
 def test_limit_and_offset(self, url, expected_pool_ids, session):
     """The pool names returned for ``url`` match ``expected_pool_ids`` exactly, in order."""
     session.add_all([Pool(pool=f"test_pool{n}", slots=1) for n in range(1, 121)])
     session.commit()
     self.assertEqual(session.query(Pool).count(), 121)  # includes the default pool
     response = self.client.get(url, environ_overrides={'REMOTE_USER': "******"})
     assert response.status_code == 200
     returned_names = [entry["name"] for entry in response.json["pools"]]
     self.assertEqual(returned_names, expected_pool_ids)
Beispiel #11
0
    def test_open_slots(self):
        """A 5-slot pool with one running and one queued task has 3 open slots."""
        pool = Pool(pool='test_pool', slots=5)
        dag = DAG(dag_id='test_open_slots', start_date=DEFAULT_DATE)
        first_task = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
        second_task = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
        running_ti = TI(task=first_task, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=second_task, execution_date=DEFAULT_DATE)
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED

        session = settings.Session
        for record in (pool, running_ti, queued_ti):
            session.add(record)
        session.commit()
        session.close()

        self.assertEqual(3, pool.open_slots())
Beispiel #12
0
    def test_response_204(self, session):
        """DELETE removes the pool (204) and a subsequent GET returns 404."""
        pool_name = "test_pool"
        session.add(Pool(pool=pool_name, slots=3))
        session.commit()

        endpoint = f"api/v1/pools/{pool_name}"
        delete_response = self.client.delete(endpoint, environ_overrides={'REMOTE_USER': "******"})
        assert delete_response.status_code == 204
        # Verify the row is really gone from the database.
        get_response = self.client.get(endpoint, environ_overrides={'REMOTE_USER': "******"})
        self.assertEqual(get_response.status_code, 404)
Beispiel #13
0
    def test_open_slots(self):
        """Running plus queued task instances in a 5-slot pool leave 3 slots open."""
        pool = Pool(pool='test_pool', slots=5)
        dag = DAG(dag_id='test_open_slots', start_date=DEFAULT_DATE)
        task_instances = []
        for task_id, state in (('dummy1', State.RUNNING), ('dummy2', State.QUEUED)):
            operator = DummyOperator(task_id=task_id, dag=dag, pool='test_pool')
            ti = TI(task=operator, execution_date=DEFAULT_DATE)
            ti.state = state
            task_instances.append(ti)

        session = settings.Session
        session.add(pool)
        for ti in task_instances:
            session.add(ti)
        session.commit()
        session.close()

        self.assertEqual(3, pool.open_slots())
Beispiel #14
0
    def test_default_pool_open_slots(self):
        """The 5-slot default pool drops to 2 open slots after a running and a 2-slot queued TI."""
        set_default_pool_slots(5)
        self.assertEqual(5, Pool.get_default_pool().open_slots())

        dag = DAG(dag_id='test_default_pool_open_slots', start_date=DEFAULT_DATE)
        one_slot_op = DummyOperator(task_id='dummy1', dag=dag)
        two_slot_op = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2)
        running_ti = TI(task=one_slot_op, execution_date=DEFAULT_DATE)
        queued_ti = TI(task=two_slot_op, execution_date=DEFAULT_DATE)
        running_ti.state = State.RUNNING
        queued_ti.state = State.QUEUED

        session = settings.Session
        for ti in (running_ti, queued_ti):
            session.add(ti)
        session.commit()
        session.close()

        # 5 total - 1 running - 2 queued
        self.assertEqual(2, Pool.get_default_pool().open_slots())
Beispiel #15
0
 def test_should_raise_400_for_invalid_orderby(self, session):
     """Ordering by the computed ``open_slots`` field is rejected with a 400."""
     session.add_all([Pool(pool=f"test_pool{k}", slots=1) for k in range(1, 121)])
     session.commit()
     assert session.query(Pool).count() == 121
     response = self.client.get(
         "/api/v1/pools?order_by=open_slots", environ_overrides={'REMOTE_USER': "******"}
     )
     assert response.status_code == 400
     expected_detail = "Ordering with 'open_slots' is disallowed or the attribute does not exist on the model"
     assert response.json['detail'] == expected_detail
Beispiel #16
0
    def test_should_raises_401_unauthenticated(self, session):
        """An unauthenticated DELETE is rejected with 401 and leaves the pool in place."""
        pool_name = "test_pool"
        session.add(Pool(pool=pool_name, slots=3))
        session.commit()

        delete_response = self.client.delete(f"api/v1/pools/{pool_name}")
        assert_401(delete_response)

        # The pool should still exist.
        get_response = self.client.get(
            f"/api/v1/pools/{pool_name}", environ_overrides={'REMOTE_USER': "******"}
        )
        assert get_response.status_code == 200
Beispiel #17
0
 def test_response_400(self, name, error_detail, url, patch_json, session):
     """An invalid PATCH payload produces a 400 problem document with the expected detail."""
     del name  # presumably only a label for the parametrized case
     session.add(Pool(pool="test_pool", slots=3))
     session.commit()
     response = self.client.patch(url, json=patch_json, environ_overrides={'REMOTE_USER': "******"})
     assert response.status_code == 400
     expected_problem = {
         "detail": error_detail,
         "status": 400,
         "title": "Bad Request",
         "type": EXCEPTIONS_LINK_MAP[400],
     }
     assert expected_problem == response.json
Beispiel #18
0
 def setUp(self):
     """Clear pools, then persist the default pool plus USER_POOL_COUNT experimental pools."""
     clear_db_pools()
     self.pools = [Pool.get_default_pool()]
     # Pool number n (1-based) gets n - 1 slots and uses its name as the description.
     self.pools.extend(
         models.Pool(pool=f'experimental_{idx + 1}', slots=idx, description=f'experimental_{idx + 1}')
         for idx in range(self.USER_POOL_COUNT)
     )
     with create_session() as session:
         session.add_all(self.pools)
Beispiel #19
0
 def test_response_200(self, session):
     """GET on a single pool returns its slot statistics."""
     session.add(Pool(pool="test_pool_a", slots=3))
     session.commit()
     response = self.client.get("/api/v1/pools/test_pool_a", environ_overrides={'REMOTE_USER': "******"})
     assert response.status_code == 200
     expected = {
         "name": "test_pool_a",
         "slots": 3,
         "occupied_slots": 0,
         "running_slots": 0,
         "queued_slots": 0,
         "open_slots": 3,
     }
     assert expected == response.json
Beispiel #20
0
 def test_response_200(self, url, patch_json, expected_name, expected_slots, session):
     """A valid PATCH returns the updated pool with every slot open."""
     session.add(Pool(pool="test_pool", slots=3))
     session.commit()
     response = self.client.patch(url, json=patch_json, environ_overrides={'REMOTE_USER': "******"})
     assert response.status_code == 200
     expected = {
         "name": expected_name,
         "slots": expected_slots,
         "occupied_slots": 0,
         "running_slots": 0,
         "queued_slots": 0,
         "open_slots": expected_slots,
     }
     assert expected == response.json
Beispiel #21
0
 def setUp(self):
     """Open a session, clear pools, and add USER_POOL_COUNT experimental pools to it.

     ``self.pools`` starts with the default pool; only the experimental pools are added
     to ``self.session`` before the commit.
     """
     self.session = settings.Session()
     clear_db_pools()
     self.pools = [Pool.get_default_pool()]
     for number in range(1, self.USER_POOL_COUNT + 1):
         pool_name = f'experimental_{number}'
         new_pool = models.Pool(pool=pool_name, slots=number - 1, description=pool_name)
         self.session.add(new_pool)
         self.pools.append(new_pool)
     self.session.commit()
Beispiel #22
0
 def test_serialize(self, session):
     """Dumping a stored pool with pool_schema reports all of its slots as open."""
     pool_model = Pool(pool="test_pool", slots=2)
     session.add(pool_model)
     session.commit()
     stored_pool = session.query(Pool).filter(Pool.pool == pool_model.pool).first()
     dumped = pool_schema.dump(stored_pool)
     expected = {
         "name": "test_pool",
         "slots": 2,
         "occupied_slots": 0,
         "running_slots": 0,
         "queued_slots": 0,
         "open_slots": 2,
     }
     assert dumped == expected
Beispiel #23
0
 def create_pool(self, name, slots, description):
     """Validate the arguments and store the pool via ``Pool.create_or_update_pool``.

     :param name: pool name; must be non-blank and no longer than the ``Pool.pool`` column.
     :param slots: slot count; any value accepted by ``int()``.
     :param description: description stored on the pool.
     :return: tuple ``(pool name, slots, description)`` read back from the stored pool.
     :raises AirflowBadRequest: if the name is blank or too long, or ``slots`` is not an integer.
     """
     if not (name and name.strip()):
         raise AirflowBadRequest("Pool name shouldn't be empty")

     pool_name_length = Pool.pool.property.columns[0].type.length
     # An unbounded String column reports length None; only enforce a real limit.
     if pool_name_length is not None and len(name) > pool_name_length:
         raise AirflowBadRequest(
             f"pool name cannot be more than {pool_name_length} characters")

     try:
         slots = int(slots)
     except (TypeError, ValueError) as err:
         # int(None) and other non-numeric types raise TypeError rather than ValueError;
         # both are bad client input and must surface as AirflowBadRequest.
         raise AirflowBadRequest(f"Bad value for `slots`: {slots}") from err

     pool = Pool.create_or_update_pool(name=name,
                                       slots=slots,
                                       description=description)
     return pool.pool, pool.slots, pool.description
Beispiel #24
0
 def test_response_400(self, error_detail, request_json, session):
     """An invalid PATCH body is answered with a 400 problem document."""
     session.add(Pool(pool="test_pool", slots=2))
     session.commit()
     response = self.client.patch("api/v1/pools/test_pool", json=request_json)
     assert response.status_code == 400
     expected_problem = {
         "detail": error_detail,
         "status": 400,
         "title": "Bad request",
         "type": "about:blank",
     }
     self.assertEqual(expected_problem, response.json)
Beispiel #25
0
 def test_response_409(self, session):
     """POSTing a pool whose name already exists yields 409 Conflict."""
     pool_name = "test_pool_a"
     session.add(Pool(pool=pool_name, slots=3))
     session.commit()
     payload = {"name": "test_pool_a", "slots": 3}
     response = self.client.post(
         "api/v1/pools", json=payload, environ_overrides={'REMOTE_USER': "******"}
     )
     assert response.status_code == 409
     expected_problem = {
         "detail": f"Pool: {pool_name} already exists",
         "status": 409,
         "title": "Conflict",
         "type": EXCEPTIONS_LINK_MAP[409],
     }
     assert expected_problem == response.json
Beispiel #26
0
 def test_response_200(self, session):
     """Fetching a single pool returns its name and slot counters."""
     session.add(Pool(pool="test_pool_a", slots=3))
     session.commit()
     response = self.client.get("/api/v1/pools/test_pool_a")
     assert response.status_code == 200
     expected = {
         "name": "test_pool_a",
         "slots": 3,
         "occupied_slots": 0,
         "running_slots": 0,
         "queued_slots": 0,
         "open_slots": 3,
     }
     self.assertEqual(expected, response.json)
Beispiel #27
0
 def test_response_400(self, name, error_detail, url, patch_json, session):
     """An invalid PATCH payload is answered with a 400 problem document."""
     del name  # presumably only a label for the parametrized case
     pool = Pool(pool="test_pool", slots=3)
     session.add(pool)
     session.commit()
     response = self.client.patch(url, json=patch_json, environ_overrides={'REMOTE_USER': "******"})
     assert response.status_code == 400
     # The argument list must open on the same line as the method name: a line break
     # after ``self.assertEqual`` turns the call into a no-op expression plus a bare tuple.
     self.assertEqual(
         {
             "detail": error_detail,
             "status": 400,
             "title": "Bad Request",
             "type": "about:blank",
         },
         response.json,
     )
 def test_response_400(self, error_detail, request_json, session):
     """An authenticated PATCH with an invalid body yields a 400 problem document."""
     session.add(Pool(pool="test_pool", slots=2))
     session.commit()
     response = self.client.patch(
         "api/v1/pools/test_pool",
         json=request_json,
         environ_overrides={'REMOTE_USER': "******"},
     )
     assert response.status_code == 400
     expected_problem = {
         "detail": error_detail,
         "status": 400,
         "title": "Bad Request",
         "type": EXCEPTIONS_LINK_MAP[400],
     }
     self.assertEqual(expected_problem, response.json)
Beispiel #29
0
 def test_response_200(self, url, patch_json, expected_name, expected_slots,
                       session):
     """A valid PATCH responds with the updated pool, all of its slots open."""
     session.add(Pool(pool="test_pool", slots=3))
     session.commit()
     response = self.client.patch(url, json=patch_json)
     assert response.status_code == 200
     expected = {
         "name": expected_name,
         "slots": expected_slots,
         "occupied_slots": 0,
         "running_slots": 0,
         "queued_slots": 0,
         "open_slots": expected_slots,
     }
     self.assertEqual(expected, response.json)
Beispiel #30
0
 def test_response_409(self, session):
     """Creating a pool with a name that is already taken returns 409."""
     pool_name = "test_pool_a"
     session.add(Pool(pool=pool_name, slots=3))
     session.commit()
     payload = {"name": "test_pool_a", "slots": 3}
     response = self.client.post(
         "api/v1/pools", json=payload, environ_overrides={'REMOTE_USER': "******"}
     )
     assert response.status_code == 409
     expected_problem = {
         "detail": f"Pool: {pool_name} already exists",
         "status": 409,
         "title": "Object already exists",
         "type": "about:blank",
     }
     self.assertEqual(expected_problem, response.json)
Beispiel #31
0
def post_pool(session):
    """Create a pool from the JSON request body.

    :param session: ORM session used to persist the new pool.
    :return: the new pool serialized with ``pool_schema.dump``.
    :raises BadRequest: if ``name`` or ``slots`` is missing, or the body fails schema validation.
    :raises AlreadyExists: if committing the pool raises ``IntegrityError``
        (reported as the pool name already existing).
    """
    # Treat a missing or empty body as "no fields" so the client gets a 400 naming the
    # missing property instead of an AttributeError (HTTP 500) on ``None.keys()``.
    body = request.json or {}
    # Pool would require both fields in the post request
    for field in ("name", "slots"):
        if field not in body:
            raise BadRequest(detail=f"'{field}' is a required property")

    try:
        post_body = pool_schema.load(body, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages))

    pool = Pool(**post_body)
    try:
        session.add(pool)
        session.commit()
    except IntegrityError:
        raise AlreadyExists(detail=f"Pool: {post_body['pool']} already exists")
    return pool_schema.dump(pool)