def get(self, dataset_id):
    with db.session_scope() as session:
        dataset = session.query(Dataset).get(dataset_id)
        if dataset is None:
            raise NotFoundException(
                f'Failed to find dataset: {dataset_id}')
        return {'data': dataset.to_dict()}
def delete(self, model_id):
    with db_handler.session_scope() as session:
        model = ModelService(session).drop(model_id)
        if not model:
            raise NotFoundException(
                f'Failed to find model: {model_id}')
        return {'data': model.to_dict()}, HTTPStatus.OK
def test_is_local(self):
    with db_handler.session_scope() as session:
        workflow, job = add_fake_workflow(session)
        self.assertTrue(workflow.is_local())
        config = workflow.get_config()
        config.job_definitions[0].is_federated = True
        workflow.set_config(config)
        # After marking a job as federated, the workflow is no longer local.
        self.assertFalse(workflow.is_local())
def get(self, model_id):
    detail_level = request.args.get('detail_level', '')
    with db_handler.session_scope() as session:
        model_json = ModelService(session).query(model_id, detail_level)
        if not model_json:
            raise NotFoundException(
                f'Failed to find model: {model_id}')
        return {'data': model_json}, HTTPStatus.OK
def put(self, model_id):
    with db_handler.session_scope() as session:
        model = session.query(Model).filter_by(id=model_id).one_or_none()
        if not model:
            raise NotFoundException(
                f'Failed to find model: {model_id}')
        model.extra = request.args.get('extra', model.extra)
        session.commit()
        return {'data': model.to_dict()}, HTTPStatus.OK
def get(self, dataset_id: int):
    if dataset_id <= 0:
        raise NotFoundException(f'Failed to find dataset: {dataset_id}')
    name = request.args.get('name', None)
    if not name:
        raise InvalidArgumentException('missing required parameter: name')
    with db.session_scope() as session:
        data = DatasetService(session).feature_metrics(name, dataset_id)
        return {'data': data}
def post(self):
    group = ModelGroup()
    group.name = request.args.get('name', group.name)
    group.extra = request.args.get('extra', group.extra)
    with db_handler.session_scope() as session:
        session.add(group)
        session.commit()
        return {'data': group.to_dict()}, HTTPStatus.OK
def get(self):
    parser = reqparse.RequestParser()
    parser.add_argument('project', type=int, required=False, help='project')
    data = parser.parse_args()
    with db.session_scope() as session:
        datasets = DatasetService(session).get_datasets(
            project_id=int(data['project'] or 0))
        return {'data': [d.to_dict() for d in datasets]}
def get(self):
    detail_level = request.args.get('detail_level', '')
    # TODO: serialized query may incur a performance penalty
    with db_handler.session_scope() as session:
        model_list = [
            ModelService(session).query(m.id, detail_level)
            for m in session.query(Model).filter(
                Model.type.in_([
                    ModelType.NN_MODEL.value, ModelType.TREE_MODEL.value
                ])).all()
        ]
        return {'data': model_list}, HTTPStatus.OK
def patch(self, group_id):
    with db_handler.session_scope() as session:
        # Query within the same session that commits, so the loaded
        # instance is not shared across sessions.
        group = session.query(ModelGroup).filter_by(
            id=group_id).one_or_none()
        if not group:
            raise NotFoundException(
                f'Failed to find group: {group_id}')
        group.name = request.args.get('name', group.name)
        group.extra = request.args.get('extra', group.extra)
        session.add(group)
        session.commit()
        return {'data': group.to_dict()}, HTTPStatus.OK
def test_patch_create_job_flags(self):
    with db_handler.session_scope() as session:
        workflow, job = add_fake_workflow(session)
        response = self.patch_helper(f'/api/v2/workflows/{workflow.id}',
                                     data={'create_job_flags': [3]})
        self.assertEqual(response.status_code, HTTPStatus.OK)
        patched_job = Job.query.get(job.id)
        self.assertEqual(patched_job.is_disabled, True)
        response = self.patch_helper(f'/api/v2/workflows/{workflow.id}',
                                     data={'create_job_flags': [1]})
        self.assertEqual(response.status_code, HTTPStatus.OK)
        patched_job = Job.query.get(job.id)
        self.assertEqual(patched_job.is_disabled, False)
def setUp(self):
    super().setUp()
    with db.session_scope() as session:
        self.default_dataset1 = Dataset(
            name='default dataset1',
            dataset_type=DatasetType.STREAMING,
            comment='test comment1',
            path='/data/dataset/123',
            project_id=1,
        )
        session.add(self.default_dataset1)
        session.commit()
    time.sleep(1)
    with db.session_scope() as session:
        self.default_dataset2 = Dataset(
            name='default dataset2',
            dataset_type=DatasetType.STREAMING,
            comment='test comment2',
            path=os.path.join(tempfile.gettempdir(), 'dataset/123'),
            project_id=2,
        )
        session.add(self.default_dataset2)
        session.commit()
def post(self, dataset_id: int):
    parser = reqparse.RequestParser()
    parser.add_argument('event_time', type=int)
    parser.add_argument('files',
                        required=True,
                        type=list,
                        location='json',
                        help=_FORMAT_ERROR_MESSAGE.format('files'))
    parser.add_argument('move', type=bool)
    parser.add_argument('comment', type=str)
    body = parser.parse_args()
    event_time = body.get('event_time')
    files = body.get('files')
    move = body.get('move', False)
    comment = body.get('comment')
    with db.session_scope() as session:
        dataset = session.query(Dataset).filter_by(id=dataset_id).first()
        if dataset is None:
            raise NotFoundException(
                f'Failed to find dataset: {dataset_id}')
        if event_time is None and dataset.type == DatasetType.STREAMING:
            raise InvalidArgumentException(
                details='data_batch.event_time is empty')
        # TODO: PSI dataset should not allow multiple batches
        # Use the current timestamp to fill event_time when type is PSI
        event_time = datetime.fromtimestamp(
            event_time or datetime.utcnow().timestamp(), tz=timezone.utc)
        batch_folder_name = event_time.strftime('%Y%m%d_%H%M%S')
        batch_path = f'{dataset.path}/batch/{batch_folder_name}'
        # Create batch
        batch = DataBatch(dataset_id=dataset.id,
                          event_time=event_time,
                          comment=comment,
                          state=BatchState.NEW,
                          move=move,
                          path=batch_path)
        batch_details = dataset_pb2.DataBatch()
        for file_path in files:
            file = batch_details.files.add()
            file.source_path = file_path
            file_name = file_path.split('/')[-1]
            file.destination_path = f'{batch_path}/{file_name}'
        batch.set_details(batch_details)
        session.add(batch)
        session.commit()
        session.refresh(batch)
        scheduler.wakeup(data_batch_ids=[batch.id])
        return {'data': batch.to_dict()}
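For illustration, a minimal client-side sketch of the JSON body this handler parses. The host and route are assumptions (the URL rule is registered elsewhere); only the field names come from the parser above.

import requests

# Hypothetical host and route; only the payload fields are implied
# by the reqparse arguments in the handler above.
resp = requests.post(
    'http://localhost:8080/api/v2/datasets/1/batches',
    json={
        'event_time': 1609430400,   # unix seconds; required for STREAMING datasets
        'files': ['/data/upload/part-0001.csv'],
        'move': False,              # copy rather than move the source files
        'comment': 'first batch',
    })
print(resp.json()['data'])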
def test_get_all_users(self):
    deleted_user = User(username='******',
                        email='*****@*****.**',
                        state=State.DELETED)
    with db.session_scope() as session:
        session.add(deleted_user)
        session.commit()
    resp = self.get_helper('/api/v2/auth/users')
    self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
    self.signin_as_admin()
    resp = self.get_helper('/api/v2/auth/users')
    self.assertEqual(resp.status_code, HTTPStatus.OK)
    self.assertEqual(len(self.get_response_data(resp)), 2)
def initial_db():
    with db.session_scope() as session:
        # initial user info first
        for u_info in INITIAL_USER_INFO:
            username = u_info['username']
            password = u_info['password']
            name = u_info['name']
            email = u_info['email']
            role = u_info['role']
            state = u_info['state']
            if session.query(User).filter_by(
                    username=username).first() is None:
                user = User(username=username,
                            name=name,
                            email=email,
                            role=role,
                            state=state)
                user.set_password(password=password)
                session.add(user)
        session.commit()
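The shape of INITIAL_USER_INFO is implied by the key lookups above; a minimal sketch with hypothetical values follows. Role is an assumed enum here, and State.ACTIVE is an assumed counterpart of the State.DELETED seen elsewhere in these examples.

# Hypothetical seed entries; only the key names are dictated by initial_db().
INITIAL_USER_INFO = [
    {
        'username': 'ada',
        'password': 'fl@12345.',
        'name': 'Ada',
        'email': 'ada@example.com',
        'role': Role.ADMIN,      # assumed Role enum
        'state': State.ACTIVE,   # assumed active state
    },
]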
def post(self):
    parser = reqparse.RequestParser()
    parser.add_argument('name',
                        required=True,
                        type=str,
                        help=_FORMAT_ERROR_MESSAGE.format('name'))
    parser.add_argument('dataset_type',
                        required=True,
                        type=DatasetType,
                        help=_FORMAT_ERROR_MESSAGE.format('dataset_type'))
    parser.add_argument('comment', type=str)
    parser.add_argument('project_id',
                        required=True,
                        type=int,
                        help=_FORMAT_ERROR_MESSAGE.format('project_id'))
    body = parser.parse_args()
    name = body.get('name')
    dataset_type = body.get('dataset_type')
    comment = body.get('comment')
    project_id = body.get('project_id')
    with db.session_scope() as session:
        try:
            # Create dataset
            dataset = Dataset(
                name=name,
                dataset_type=dataset_type,
                comment=comment,
                path=_get_dataset_path(name),
                project_id=project_id,
            )
            session.add(dataset)
            # TODO: scan cronjob
            session.commit()
            return {'data': dataset.to_dict()}
        except Exception as e:
            session.rollback()
            raise InvalidArgumentException(details=str(e))
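A matching request sketch, assuming the resource is mounted at /api/v2/datasets and that DatasetType can be constructed from its raw string value (reqparse calls type= on the incoming field):

import requests

# Hypothetical host and route; field names mirror the parser above.
resp = requests.post(
    'http://localhost:8080/api/v2/datasets',
    json={
        'name': 'my_dataset',
        'dataset_type': 'STREAMING',  # parsed via type=DatasetType
        'comment': 'created from the API',
        'project_id': 1,
    })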
def patch(self, dataset_id: int):
    parser = reqparse.RequestParser()
    parser.add_argument('name',
                        type=str,
                        required=False,
                        help='dataset name')
    parser.add_argument('comment',
                        type=str,
                        required=False,
                        help='dataset comment')
    data = parser.parse_args()
    with db.session_scope() as session:
        dataset = session.query(Dataset).filter_by(id=dataset_id).first()
        if not dataset:
            raise NotFoundException(
                f'Failed to find dataset: {dataset_id}')
        if data['name']:
            dataset.name = data['name']
        if data['comment']:
            dataset.comment = data['comment']
        session.commit()
        return {'data': dataset.to_dict()}, HTTPStatus.OK
def get(self, dataset_id: int):
    if dataset_id <= 0:
        raise NotFoundException(f'Failed to find dataset: {dataset_id}')
    with db.session_scope() as session:
        data = DatasetService(session).get_dataset_preview(dataset_id)
        return {'data': data}