def test_serialize(self):
    """Dump a two-pool ``PoolCollection`` and compare it with the expected payload."""
    first_pool = Pool(pool="test_pool_a", slots=3)
    second_pool = Pool(pool="test_pool_b", slots=3)
    collection = PoolCollection(pools=[first_pool, second_pool], total_entries=2)

    expected = {
        "pools": [
            {
                "name": "test_pool_a",
                "slots": 3,
                "occupied_slots": 0,
                "running_slots": 0,
                "queued_slots": 0,
                "open_slots": 3,
            },
            {
                "name": "test_pool_b",
                "slots": 3,
                "occupied_slots": 0,
                "running_slots": 0,
                "queued_slots": 0,
                "open_slots": 3,
            },
        ],
        "total_entries": 2,
    }

    assert expected == pool_collection_schema.dump(collection)
def test_default_pool_open_slots(self):
    """Check open slots and slot stats of the default pool with one running and one queued TI."""
    set_default_pool_slots(5)
    assert 5 == Pool.get_default_pool().open_slots()

    dag = DAG(dag_id='test_default_pool_open_slots', start_date=DEFAULT_DATE)
    running_op = DummyOperator(task_id='dummy1', dag=dag)
    # The queued task asks for two slots instead of the default.
    queued_op = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2)

    running_ti = TI(task=running_op, execution_date=DEFAULT_DATE)
    queued_ti = TI(task=queued_op, execution_date=DEFAULT_DATE)
    running_ti.state = State.RUNNING
    queued_ti.state = State.QUEUED

    db = settings.Session
    db.add(running_ti)
    db.add(queued_ti)
    db.commit()
    db.close()

    assert 2 == Pool.get_default_pool().open_slots()

    expected_stats = {
        "default_pool": {
            "open": 2,
            "queued": 2,
            "total": 5,
            "running": 1,
        }
    }
    assert expected_stats == Pool.slots_stats()
def add_default_pool_if_not_exists(session=None):
    """Add and commit the default pool when ``Pool.get_pool`` does not find one."""
    from airflow.models.pool import Pool

    # Guard clause: nothing to do when the default pool is already present.
    if Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session):
        return

    new_pool = Pool(
        pool=Pool.DEFAULT_POOL_NAME,
        slots=conf.getint(section='core', key='non_pooled_task_slot_count', fallback=128),
        description="Default pool",
    )
    session.add(new_pool)
    session.commit()
def add_default_pool_if_not_exists(session=None):
    """
    Create the default pool unless it is already stored.

    The slot count is read from ``core.non_pooled_task_slot_count`` and falls
    back to 128 when that option is not set.
    """
    already_there = Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session)
    if not already_there:
        session.add(
            Pool(
                pool=Pool.DEFAULT_POOL_NAME,
                slots=conf.getint(section='core', key='non_pooled_task_slot_count', fallback=128),
                description="Default pool",
            )
        )
        session.commit()
def test_response_200(self, session):
    """List pools and expect the default pool followed by the pool created here."""
    session.add(Pool(pool="test_pool_a", slots=3))
    session.commit()

    stored_pools = session.query(Pool).all()
    assert len(stored_pools) == 2  # the default pool is counted as well

    response = self.client.get("/api/v1/pools", environ_overrides={'REMOTE_USER': "******"})
    assert response.status_code == 200

    expected = {
        "pools": [
            {
                "name": "default_pool",
                "slots": 128,
                "occupied_slots": 0,
                "running_slots": 0,
                "queued_slots": 0,
                "open_slots": 128,
            },
            {
                "name": "test_pool_a",
                "slots": 3,
                "occupied_slots": 0,
                "running_slots": 0,
                "queued_slots": 0,
                "open_slots": 3,
            },
        ],
        "total_entries": 2,
    }
    self.assertEqual(expected, response.json)
def _validate_task_pools(self, *, dagbag: DagBag, session: Session = NEW_SESSION):
    """
    Record a warning for every DAG or subdag whose tasks reference a non-existent pool.

    Nothing is raised: for each offending DAG a ``DagWarning`` of type
    ``DagWarningType.NONEXISTENT_POOL`` is added to ``self.dag_warnings``.

    :param dagbag: bag whose ``dags`` values, and each DAG's ``subdags``, are checked
    :param session: session handed to ``Pool.get_pools`` to list the existing pools

    :meta private:
    """
    from airflow.models.pool import Pool

    # Fetch the existing pool names once and reuse them for every DAG.
    known_pools = {p.pool for p in Pool.get_pools(session)}

    def record_missing_pools(dag):
        """Add a warning for ``dag`` when any of its tasks uses an unknown pool."""
        nonexistent_pools = {task.pool for task in dag.tasks} - known_pools
        if not nonexistent_pools:
            return
        # ``sorted`` already returns a list, so no extra ``list()`` wrapper is needed.
        message = f"Dag '{dag.dag_id}' references non-existent pools: {sorted(nonexistent_pools)!r}"
        self.dag_warnings.add(DagWarning(dag.dag_id, DagWarningType.NONEXISTENT_POOL, message))

    for dag in dagbag.dags.values():
        record_missing_pools(dag)
        for subdag in dag.subdags:
            record_missing_pools(subdag)
def test_should_respect_page_size_limit_default(self, session):
    """With 121 pools stored and no ``limit`` given, the listing returns 100 pools."""
    session.add_all([Pool(pool=f"test_pool{i}", slots=1) for i in range(1, 121)])
    session.commit()

    stored_count = session.query(Pool).count()
    self.assertEqual(stored_count, 121)

    response = self.client.get("/api/v1/pools", environ_overrides={'REMOTE_USER': "******"})
    assert response.status_code == 200

    returned_pools = response.json['pools']
    self.assertEqual(len(returned_pools), 100)
def test_should_return_conf_max_if_req_max_above_conf(self, session):
    """Requesting ``limit=180`` with 200 pools stored yields only 150 pools."""
    session.add_all([Pool(pool=f"test_pool{i}", slots=1) for i in range(1, 200)])
    session.commit()

    stored_count = session.query(Pool).count()
    self.assertEqual(stored_count, 200)

    auth = {'REMOTE_USER': "******"}
    response = self.client.get("/api/v1/pools?limit=180", environ_overrides=auth)
    assert response.status_code == 200

    returned_pools = response.json['pools']
    self.assertEqual(len(returned_pools), 150)
def test_should_raises_401_unauthenticated(self, session):
    """Send a PATCH without ``REMOTE_USER`` and verify the response with ``assert_401``."""
    session.add(Pool(pool="test_pool", slots=2))
    session.commit()

    patch_body = {"name": "test_pool_a", "slots": 3}
    response = self.client.patch("api/v1/pools/test_pool", json=patch_body)

    assert_401(response)
def test_limit_and_offset(self, url, expected_pool_ids, session):
    """Query ``url`` and compare the returned pool names with ``expected_pool_ids``."""
    session.add_all([Pool(pool=f"test_pool{i}", slots=1) for i in range(1, 121)])
    session.commit()

    stored_count = session.query(Pool).count()
    self.assertEqual(stored_count, 121)  # the default pool is counted as well

    response = self.client.get(url, environ_overrides={'REMOTE_USER': "******"})
    assert response.status_code == 200

    returned_names = [entry["name"] for entry in response.json["pools"]]
    self.assertEqual(returned_names, expected_pool_ids)
def test_open_slots(self):
    """A 5-slot pool with one running and one queued task instance reports 3 open slots."""
    test_pool = Pool(pool='test_pool', slots=5)
    dag = DAG(dag_id='test_open_slots', start_date=DEFAULT_DATE)

    running_op = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
    queued_op = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')

    running_ti = TI(task=running_op, execution_date=DEFAULT_DATE)
    queued_ti = TI(task=queued_op, execution_date=DEFAULT_DATE)
    running_ti.state = State.RUNNING
    queued_ti.state = State.QUEUED

    db = settings.Session
    db.add(test_pool)
    db.add(running_ti)
    db.add(queued_ti)
    db.commit()
    db.close()

    self.assertEqual(3, test_pool.open_slots())
def test_response_204(self, session):
    """Delete a pool, expect 204, then expect 404 when fetching it again."""
    pool_name = "test_pool"
    session.add(Pool(pool=pool_name, slots=3))
    session.commit()

    auth = {'REMOTE_USER': "******"}
    pool_url = f"api/v1/pools/{pool_name}"

    response = self.client.delete(pool_url, environ_overrides=auth)
    assert response.status_code == 204

    # A follow-up GET confirms the pool is no longer available.
    response = self.client.get(pool_url, environ_overrides=auth)
    self.assertEqual(response.status_code, 404)
def test_default_pool_open_slots(self):
    """The default pool drops from 5 to 2 open slots once two task instances occupy it."""
    set_default_pool_slots(5)
    self.assertEqual(5, Pool.get_default_pool().open_slots())

    dag = DAG(dag_id='test_default_pool_open_slots', start_date=DEFAULT_DATE)

    # One task with the default slot usage, one that asks for two slots.
    first_op = DummyOperator(task_id='dummy1', dag=dag)
    second_op = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2)

    first_ti = TI(task=first_op, execution_date=DEFAULT_DATE)
    first_ti.state = State.RUNNING
    second_ti = TI(task=second_op, execution_date=DEFAULT_DATE)
    second_ti.state = State.QUEUED

    db = settings.Session
    db.add(first_ti)
    db.add(second_ti)
    db.commit()
    db.close()

    self.assertEqual(2, Pool.get_default_pool().open_slots())
def test_should_raise_400_for_invalid_orderby(self, session):
    """Ordering the pool listing by ``open_slots`` is answered with a 400 and a detail message."""
    session.add_all([Pool(pool=f"test_pool{i}", slots=1) for i in range(1, 121)])
    session.commit()

    stored_count = session.query(Pool).count()
    assert stored_count == 121

    auth = {'REMOTE_USER': "******"}
    response = self.client.get("/api/v1/pools?order_by=open_slots", environ_overrides=auth)
    assert response.status_code == 400

    expected_detail = "Ordering with 'open_slots' is disallowed or the attribute does not exist on the model"
    assert response.json['detail'] == expected_detail
def test_should_raises_401_unauthenticated(self, session):
    """A DELETE without ``REMOTE_USER`` is rejected and leaves the pool in place."""
    pool_name = "test_pool"
    session.add(Pool(pool=pool_name, slots=3))
    session.commit()

    unauthenticated = self.client.delete(f"api/v1/pools/{pool_name}")
    assert_401(unauthenticated)

    # The pool should still exist after the rejected request.
    authenticated = self.client.get(
        f"/api/v1/pools/{pool_name}", environ_overrides={'REMOTE_USER': "******"}
    )
    assert authenticated.status_code == 200
def test_response_400(self, name, error_detail, url, patch_json, session):
    """PATCH ``url`` with ``patch_json`` and expect a 400 body carrying ``error_detail``."""
    del name  # not used by the test body
    session.add(Pool(pool="test_pool", slots=3))
    session.commit()

    auth = {'REMOTE_USER': "******"}
    response = self.client.patch(url, json=patch_json, environ_overrides=auth)
    assert response.status_code == 400

    expected = {
        "detail": error_detail,
        "status": 400,
        "title": "Bad Request",
        "type": EXCEPTIONS_LINK_MAP[400],
    }
    assert expected == response.json
def setUp(self):
    """Clear the pools, then store the default pool plus ``USER_POOL_COUNT`` user pools."""
    clear_db_pools()
    default_pool = Pool.get_default_pool()

    user_pools = []
    for i in range(self.USER_POOL_COUNT):
        name = f'experimental_{i + 1}'
        user_pools.append(models.Pool(pool=name, slots=i, description=name))

    self.pools = [default_pool, *user_pools]
    with create_session() as session:
        session.add_all(self.pools)
def test_response_200(self, session):
    """Fetch a single pool by name and compare the JSON body with the expected fields."""
    session.add(Pool(pool="test_pool_a", slots=3))
    session.commit()

    auth = {'REMOTE_USER': "******"}
    response = self.client.get("/api/v1/pools/test_pool_a", environ_overrides=auth)
    assert response.status_code == 200

    expected = {
        "name": "test_pool_a",
        "slots": 3,
        "occupied_slots": 0,
        "running_slots": 0,
        "queued_slots": 0,
        "open_slots": 3,
    }
    assert expected == response.json
def test_response_200(self, url, patch_json, expected_name, expected_slots, session):
    """PATCH a pool and expect the response to carry the expected name and slot count."""
    session.add(Pool(pool="test_pool", slots=3))
    session.commit()

    auth = {'REMOTE_USER': "******"}
    response = self.client.patch(url, json=patch_json, environ_overrides=auth)
    assert response.status_code == 200

    expected = {
        "name": expected_name,
        "slots": expected_slots,
        "occupied_slots": 0,
        "running_slots": 0,
        "queued_slots": 0,
        "open_slots": expected_slots,
    }
    assert expected == response.json
def setUp(self):
    """
    Prepare the pool fixtures for each test.

    Opens ``self.session``, clears the stored pools, and builds ``self.pools``
    from the default pool followed by ``USER_POOL_COUNT`` user pools named
    ``experimental_1`` .. ``experimental_N``. The user pools are added to the
    session and committed in one go.
    """
    self.session = settings.Session()
    clear_db_pools()
    self.pools = [Pool.get_default_pool()]
    for i in range(self.USER_POOL_COUNT):
        # f-string instead of %-formatting; produces the same name.
        name = f'experimental_{i + 1}'
        pool = models.Pool(
            pool=name,
            slots=i,
            description=name,
        )
        self.session.add(pool)
        self.pools.append(pool)
    self.session.commit()
def test_serialize(self, session):
    """Store a pool, read it back through a query and check its schema dump."""
    stored_pool = Pool(pool="test_pool", slots=2)
    session.add(stored_pool)
    session.commit()

    by_name = session.query(Pool).filter(Pool.pool == stored_pool.pool)
    fetched_pool = by_name.first()

    expected = {
        "name": "test_pool",
        "slots": 2,
        "occupied_slots": 0,
        "running_slots": 0,
        "queued_slots": 0,
        "open_slots": 2,
    }
    dumped = pool_schema.dump(fetched_pool)
    assert dumped == expected
def create_pool(self, name, slots, description):
    """
    Validate the arguments and create or update a pool.

    :param name: pool name; must be non-blank and no longer than the length of
        the ``Pool.pool`` column
    :param slots: slot count; anything ``int()`` accepts
    :param description: description forwarded to ``Pool.create_or_update_pool``
    :return: tuple ``(pool, slots, description)`` read from the resulting pool
    :raises AirflowBadRequest: if the name is empty/blank or too long, or if
        ``slots`` cannot be converted to an integer
    """
    if not (name and name.strip()):
        raise AirflowBadRequest("Pool name shouldn't be empty")

    pool_name_length = Pool.pool.property.columns[0].type.length
    if len(name) > pool_name_length:
        raise AirflowBadRequest(f"pool name cannot be more than {pool_name_length} characters")

    try:
        slots = int(slots)
    except (TypeError, ValueError) as err:
        # ``int(None)`` raises TypeError rather than ValueError; report both as
        # a bad request and keep the original error as the cause.
        raise AirflowBadRequest(f"Bad value for `slots`: {slots}") from err

    pool = Pool.create_or_update_pool(name=name, slots=slots, description=description)
    return pool.pool, pool.slots, pool.description
def test_response_400(self, error_detail, request_json, session):
    """PATCH ``request_json`` to an existing pool and expect a 400 body with ``error_detail``."""
    session.add(Pool(pool="test_pool", slots=2))
    session.commit()

    response = self.client.patch("api/v1/pools/test_pool", json=request_json)
    assert response.status_code == 400

    expected = {
        "detail": error_detail,
        "status": 400,
        "title": "Bad request",
        "type": "about:blank",
    }
    self.assertEqual(expected, response.json)
def test_response_409(self, session):
    """POSTing a pool whose name is already stored is answered with a 409 Conflict."""
    pool_name = "test_pool_a"
    session.add(Pool(pool=pool_name, slots=3))
    session.commit()

    post_body = {"name": "test_pool_a", "slots": 3}
    auth = {'REMOTE_USER': "******"}
    response = self.client.post("api/v1/pools", json=post_body, environ_overrides=auth)
    assert response.status_code == 409

    expected = {
        "detail": f"Pool: {pool_name} already exists",
        "status": 409,
        "title": "Conflict",
        "type": EXCEPTIONS_LINK_MAP[409],
    }
    assert expected == response.json
def test_response_200(self, session):
    """GET a single pool by name and compare the JSON body with the expected fields."""
    session.add(Pool(pool="test_pool_a", slots=3))
    session.commit()

    response = self.client.get("/api/v1/pools/test_pool_a")
    assert response.status_code == 200

    expected = {
        "name": "test_pool_a",
        "slots": 3,
        "occupied_slots": 0,
        "running_slots": 0,
        "queued_slots": 0,
        "open_slots": 3,
    }
    self.assertEqual(expected, response.json)
def test_response_400(self, name, error_detail, url, patch_json, session):
    """PATCH ``url`` with ``patch_json`` and expect a 400 body carrying ``error_detail``."""
    del name  # not used by the test body
    session.add(Pool(pool="test_pool", slots=3))
    session.commit()

    response = self.client.patch(
        url,
        json=patch_json,
        environ_overrides={'REMOTE_USER': "******"},
    )
    assert response.status_code == 400

    expected = {
        "detail": error_detail,
        "status": 400,
        "title": "Bad Request",
        "type": "about:blank",
    }
    self.assertEqual(expected, response.json)
def test_response_400(self, error_detail, request_json, session):
    """PATCH ``request_json`` to an existing pool and expect a 400 body with ``error_detail``."""
    session.add(Pool(pool="test_pool", slots=2))
    session.commit()

    auth = {'REMOTE_USER': "******"}
    response = self.client.patch("api/v1/pools/test_pool", json=request_json, environ_overrides=auth)
    assert response.status_code == 400

    expected = {
        "detail": error_detail,
        "status": 400,
        "title": "Bad Request",
        "type": EXCEPTIONS_LINK_MAP[400],
    }
    self.assertEqual(expected, response.json)
def test_response_200(self, url, patch_json, expected_name, expected_slots, session):
    """PATCH a pool and expect the response to carry the expected name and slot count."""
    session.add(Pool(pool="test_pool", slots=3))
    session.commit()

    response = self.client.patch(url, json=patch_json)
    assert response.status_code == 200

    expected = {
        "name": expected_name,
        "slots": expected_slots,
        "occupied_slots": 0,
        "running_slots": 0,
        "queued_slots": 0,
        "open_slots": expected_slots,
    }
    self.assertEqual(expected, response.json)
def test_response_409(self, session):
    """POSTing a pool whose name is already stored is answered with a 409."""
    pool_name = "test_pool_a"
    session.add(Pool(pool=pool_name, slots=3))
    session.commit()

    response = self.client.post(
        "api/v1/pools",
        json={"name": "test_pool_a", "slots": 3},
        environ_overrides={'REMOTE_USER': "******"},
    )
    assert response.status_code == 409

    expected = {
        "detail": f"Pool: {pool_name} already exists",
        "status": 409,
        "title": "Object already exists",
        "type": "about:blank",
    }
    self.assertEqual(expected, response.json)
def post_pool(session):
    """
    Create a pool from the JSON body of the current request.

    :param session: session used to load the schema and to persist the new pool
    :return: the ``pool_schema`` dump of the created pool
    :raises BadRequest: if ``name`` or ``slots`` is missing from the body, or
        the body fails schema validation
    :raises AlreadyExists: if committing the pool hits an ``IntegrityError``
    """
    # A pool requires both fields in the POST request.
    required_fields = ("name", "slots")
    for field in required_fields:
        if field not in request.json:
            raise BadRequest(detail=f"'{field}' is a required property")

    try:
        post_body = pool_schema.load(request.json, session=session)
    except ValidationError as err:
        raise BadRequest(detail=str(err.messages)) from err

    pool = Pool(**post_body)
    try:
        session.add(pool)
        session.commit()
    except IntegrityError as err:
        raise AlreadyExists(detail=f"Pool: {post_body['pool']} already exists") from err

    # Serialise outside the ``try`` so only the persistence step is guarded.
    return pool_schema.dump(pool)