def setUp(self):
    """Build a Flask test app backed by DATABASE_URL and expose a test client.

    Fix: the original only checked DATABASE_URL *after* it had already been
    written into the config and SQLAlchemy had been initialized with a
    possibly-None URI; now we fail fast before any setup happens.
    """
    database_url = os.getenv("DATABASE_URL")
    if not database_url:
        raise RuntimeError("DATABASE_URL is not set")
    app = Flask(__name__)
    app.config["SQLALCHEMY_DATABASE_URI"] = database_url
    app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
    app.config['TESTING'] = True
    app.config['WTF_CSRF_ENABLED'] = False
    app.config['DEBUG'] = False
    db = SQLAlchemy()
    db.init_app(app)
    # Push an app context so DB operations work inside the tests.
    app.app_context().push()
    self.app = app.test_client()
def parse_node(self, response, node):
    """Parse one CDS MARCXML node into a hep-format ParsedItem.

    Raises RuntimeError when the bibrecord cannot be created.
    """
    node.remove_namespaces()
    cds_bibrec, ok, errs = create_bibrec(
        node.xpath('.//record').extract()[0]
    )
    if not ok:
        # Fix: RuntimeError does not apply %-style args like a logger does;
        # interpolate explicitly so the message actually shows node/errs.
        raise RuntimeError("Cannot parse record %s: %s" % (node, errs))
    self.logger.info("Here's the record: %s", cds_bibrec)
    inspire_bibrec = CDS2Inspire(cds_bibrec).get_record()
    marcxml_record = record_xml_output(inspire_bibrec)
    record = create_record(marcxml_record)
    # dojson's hep rules need an active Flask application context.
    app = Flask('hepcrawl')
    app.config.update(
        self.settings.getdict('MARC_TO_HEP_SETTINGS', {})
    )
    with app.app_context():
        json_record = hep.do(record)
        base_uri = self.settings['SCHEMA_BASE_URI']
        json_record['$schema'] = base_uri + 'hep.json'
        parsed_item = ParsedItem(
            record=json_record,
            record_format='hep',
        )
        return parsed_item
def init_app(name=None):
    """Create and fully wire the application.

    Order matters throughout: config must exist before LDAP/hydra clients,
    the DB before create_all(), and views are imported late to avoid
    circular imports.

    :param name: Flask app name; defaults to this module's ``__name__``.
    :return: the configured Flask application.
    """
    name = name or __name__
    app = Flask(name)
    app.config.from_pyfile('application.cfg')
    # production.cfg overrides application.cfg values.
    app.config.from_pyfile('production.cfg')
    app.jinja_env.globals['GIT_HASH'] = get_git_hash()
    #app.ldap_orm = Connection(app.config['LDAP_URL'], app.config['LDAP_BIND_DN'], app.config['LDAP_BIND_PW'], auto_bind=True)
    # Bind to LDAP once at startup; the model module shares the connection.
    server = Server(app.config['LDAP_URL'], get_info=ALL)
    app.ldap_conn = Connection(server, app.config['LDAP_BIND_DN'],
                               app.config['LDAP_BIND_PW'], auto_bind=True)
    model.ldap_conn = app.ldap_conn
    model.base_dn = app.config['LDAP_BASE_DN']

    from .model import db
    db.init_app(app)
    with app.app_context():
        db.create_all()

    app.babel = Babel(app)
    init_oauth2(app)
    app.login_manager = LoginManager(app)

    # init hydra admin api
    hydra_config = hydra.Configuration(
        host=app.config['HYDRA_ADMIN_URL'],
        username=app.config['HYDRA_ADMIN_USER'],
        password=app.config['HYDRA_ADMIN_PASSWORD'])
    hydra_client = hydra.ApiClient(hydra_config)
    app.hydra_api = hydra.AdminApi(hydra_client)

    # Imported here (not at module top) to avoid circular imports.
    from .views import auth_views, frontend_views, init_login_manager, api_views, pki_views, admin_views
    init_login_manager(app)
    app.register_blueprint(auth_views)
    app.register_blueprint(frontend_views)
    app.register_blueprint(api_views)
    app.register_blueprint(pki_views)
    app.register_blueprint(admin_views)

    # NOTE(review): name has a typo ("befor") but it is only a decorator
    # target, never referenced elsewhere.
    @app.before_request
    def befor_request():
        # Expose a per-request elapsed-time callable (e.g. for templates).
        request_start_time = time.time()
        g.request_time = lambda: "%.5fs" % (time.time() - request_start_time)

    from .translations import init_babel
    init_babel(app)

    # Materialize the configured lenticular services as model objects.
    app.lenticular_services = {}
    for service_name, service_config in app.config[
            'LENTICULAR_CLOUD_SERVICES'].items():
        app.lenticular_services[service_name] = model.Service.from_config(
            service_name, service_config)

    app.pki = Pki(app.config['PKI_PATH'], app.config['DOMAIN'])

    return app
def _create_json_record(xml_record):
    """Convert a MARCXML string into a hep JSON record via dojson."""
    marc_record = create_record(etree.XML(xml_record))
    # dojson's hep rules require an active Flask application context.
    flask_app = Flask('hepcrawl')
    flask_app.config.update(
        self.settings.getdict('MARC_TO_HEP_SETTINGS', {})
    )
    with flask_app.app_context():
        return hep.do(marc_record)
def init_app(app: Flask):
    """Attach the shared SQLAlchemy handle to *app* and build the schema."""
    # Importing the models registers their tables before create_all().
    from src.api import models  # noqa: F401
    db.init_app(app)
    print(emojize('Base de dados conectada :outbox_tray:'))
    with app.app_context():
        db.create_all()
    app.db = db
def train_model(current_app: Flask, model_id: int, job_id: int, epochs: int,
                from_checkpoint: bool):
    """Load a neural network (optionally from a checkpoint) and train it.

    Runs entirely inside *current_app*'s application context so config
    lookups work from a background job.
    """
    with current_app.app_context():
        cfg = current_app.config
        dataset = DatasetReader(
            cfg['TRAINING_IMAGES'],
            cfg['TRAINING_LABELS'],
            cfg['TESTING_IMAGES'],
            cfg['TESTING_LABELS'],
        )
        network = NeuralNetwork(dataset, cfg)
        network.load(model_id, job_id=job_id, from_checkpoint=from_checkpoint)
        network.train(model_id, epochs)
def create_app(config_name: str = "default") -> Flask:
    """Application factory: load config, then wire extensions,
    blueprints and CLI commands inside an app context."""
    application = Flask(__name__)
    init_config(application, config_name)
    with application.app_context():
        for initializer in (init_extensions, init_blueprints, init_commands):
            initializer(application)
    return application
def _create_json_record(xml_record):
    """MARCXML string -> hep JSON dict with its ``$schema`` URI attached."""
    marc_record = create_record(etree.XML(xml_record))
    flask_app = Flask('hepcrawl')
    flask_app.config.update(
        self.settings.getdict('MARC_TO_HEP_SETTINGS', {})
    )
    # hep.do needs an active application context to run.
    with flask_app.app_context():
        json_record = hep.do(marc_record)
        json_record['$schema'] = (
            self.settings['SCHEMA_BASE_URI'] + 'hep.json'
        )
    return json_record
def _handle_recurring_scheduler_job(job_type: str, interval: int,
                                    handle_func: Callable, app: Flask) -> None:
    """Run one recurring scheduler job transactionally in the app context.

    An IntegrityError is expected when the job row already exists (another
    worker won the race) and is only logged at debug level; any other
    failure is logged as an error.
    """
    try:
        with app.app_context():
            with TwoPhaseExecutor(db.session) as tpe:
                _HandleRecurringSchedulerJob(tpe).transaction(
                    job_type, interval, handle_func, app)
    except sqlalchemy.exc.IntegrityError:
        logger.debug(f"SchedulerJob with type {job_type} already exists.")
    except Exception:
        # Fix: logger.error discarded the traceback; logger.exception
        # records the stack trace so the failure is diagnosable.
        logger.exception(f"Failed to run job with type: {job_type}.")
def _parsed_items_from_marcxml(
        self, marcxml_records, base_url="", hostname="", url_schema=None,
        ftp_params=None, url=""):
    """Convert raw MARCXML strings into hep ParsedItems.

    A record that fails to parse does not abort the batch: the exception is
    captured into an error ParsedItem and processing continues.

    :param marcxml_records: iterable of MARCXML strings.
    :param url: source URL; its last path segment becomes ``file_name``.
    :return: list of ParsedItem (successful and error items mixed).
    """
    # marcxml2record's dojson rules need an active Flask app context.
    app = Flask('hepcrawl')
    app.config.update(self.settings.getdict('MARC_TO_HEP_SETTINGS', {}))
    file_name = url.split('/')[-1]
    with app.app_context():
        parsed_items = []
        for xml_record in marcxml_records:
            try:
                record = marcxml2record(xml_record)
                parsed_item = ParsedItem(record=record, record_format='hep')
                parsed_item.ftp_params = ftp_params
                parsed_item.file_name = file_name
                # Resolve each attached document URL and keep only those
                # that actually need downloading.
                files_to_download = [
                    self._get_full_uri(
                        current_url=document['url'],
                        base_url=base_url,
                        schema=url_schema,
                        hostname=hostname,
                    )
                    for document in parsed_item.record.get('documents', [])
                    if self._has_to_be_downloaded(document['url'])
                ]
                parsed_item.file_urls = files_to_download
                self.logger.info('Got the following attached documents to download: %s'% files_to_download)
                self.logger.info('Got item: %s' % parsed_item)
                parsed_items.append(parsed_item)
            except Exception as e:
                tb = ''.join(traceback.format_tb(sys.exc_info()[2]))
                error_parsed_item = ParsedItem.from_exception(
                    record_format='hep',
                    exception=repr(e),
                    traceback=tb,
                    source_data=xml_record,
                    file_name=file_name
                )
                parsed_items.append(error_parsed_item)
        return parsed_items
def _get_crawl_result(xml_record):
    """Parse one MARCXML record into a ParsedItem.

    On failure the exception, traceback and raw source are attached to the
    item instead of being raised.
    """
    app = Flask('hepcrawl')
    app.config.update(self.settings.getdict('MARC_TO_HEP_SETTINGS', {}))
    with app.app_context():
        item = ParsedItem(record={}, record_format='hep')
        try:
            item.record = marcxml2record(xml_record)
        except Exception as e:
            item.exception = repr(e)
            # Fix (consistency): every sibling error path stores the
            # traceback as a single joined string, not a list of frames.
            item.traceback = ''.join(
                traceback.format_tb(sys.exc_info()[2]))
            item.source_data = xml_record
        return item
def _parsed_item_from_marcxml(marcxml_record, settings):
    """Build a hep ParsedItem from CDS MARCXML, or an error item on failure."""
    app = Flask('hepcrawl')
    app.config.update(settings.getdict('MARC_TO_HEP_SETTINGS', {}))
    with app.app_context():
        try:
            return ParsedItem(
                record=cds_marcxml2record(marcxml_record),
                record_format='hep',
            )
        except Exception as exc:
            # Capture the failure on an error item rather than raising.
            frames = ''.join(traceback.format_tb(sys.exc_info()[2]))
            return ParsedItem.from_exception(
                record_format='hep',
                exception=repr(exc),
                traceback=frames,
                source_data=marcxml_record,
            )
def _parsed_items_from_marcxml(self, marcxml_records, base_url="", url=""):
    """Convert raw MARCXML strings into hep ParsedItems.

    Local document URLs are rewritten to full URIs; remote ones are queued
    in ``file_urls`` for download. A record that fails to parse becomes an
    error ParsedItem instead of aborting the batch.

    :param url: source URL; last path segment (query stripped) is the file name.
    """
    self.logger.info('parsing record')
    # marcxml2record's dojson rules need an active Flask app context.
    app = Flask('hepcrawl')
    app.config.update(self.settings.getdict('MARC_TO_HEP_SETTINGS', {}))
    file_name = url.split('/')[-1].split("?")[0]
    with app.app_context():
        parsed_items = []
        for xml_record in marcxml_records:
            try:
                record = marcxml2record(xml_record)
                parsed_item = ParsedItem(record=record, record_format='hep')
                parsed_item.file_name = file_name
                new_documents = []
                files_to_download = []
                self.logger.info("Parsed document: %s", parsed_item.record)
                self.logger.info("Record have documents: %s",
                                 "documents" in parsed_item.record)
                for document in parsed_item.record.get('documents', []):
                    if self._is_local_path(document['url']):
                        document['url'] = self._get_full_uri(
                            document['url'])
                        self.logger.info("Updating document %s", document)
                    else:
                        files_to_download.append(document['url'])
                    # NOTE(review): every document (local or remote) appears
                    # to be kept in the rewritten list — confirm intent.
                    new_documents.append(document)
                if new_documents:
                    parsed_item.record['documents'] = new_documents
                parsed_item.file_urls = files_to_download
                self.logger.info(
                    'Got the following attached documents to download: %s',
                    files_to_download)
                self.logger.info('Got item: %s', parsed_item)
                parsed_items.append(parsed_item)
            except Exception as e:
                tb = ''.join(traceback.format_tb(sys.exc_info()[2]))
                error_parsed_item = ParsedItem.from_exception(
                    record_format='hep',
                    exception=repr(e),
                    traceback=tb,
                    source_data=xml_record,
                    file_name=file_name)
                parsed_items.append(error_parsed_item)
        return parsed_items
def test_create_superuser(runner: FlaskCliRunner, app: Flask, username: str,
                          email: str, password: str, output: str):
    """The createsuperuser CLI creates a staff+superuser account on success."""
    result = runner.invoke(args=[
        "createsuperuser",
        "--username", username,
        "--email", email,
        "--password", password,
    ])
    assert output in result.output
    # Only verify the DB row when the command reported success.
    if output != "Created":
        return
    with app.app_context():
        created = User.query.filter_by(username=username).first()
        assert created
        assert created.stuff
        assert created.superuser
def stuff_client(app: Flask):
    """Return a test client pre-authenticated as a staff ("stuff") user."""
    client = app.test_client()
    staff_user = User("stuff", "*****@*****.**", stuff=True)
    staff_user.set_password("stuffpsw")
    with app.app_context():
        db.session.add(staff_user)
        db.session.commit()
    # Log in through the API to obtain a bearer token for the client.
    res = client.post("/auth/token/login", json={
        "username": "******",
        "password": "******",
    })
    assert res.is_json
    token = res.get_json().get("token")
    client.environ_base["HTTP_AUTHORIZATION"] = "Bearer {}".format(token)
    return client
def app_factory(config): '''This factory creates a Flask application instance based on the settings in the provided configuration object.''' # Create the Flask app, register the blueprint and initialize the # flask-mail service. # Blueprints must be used to implement factories (I believe) because # they allow the factory to register the application's routes # before they must be implemented. app = Flask(__name__) app.config.from_object(config) from app.views import web app.register_blueprint(web) mail.init_app(app) # Create the (only) mongodb instance for use by all running applications. # Different apps may use different Mongo databases. # The production server already has its data, so don't always # call db_reset(). mongo = PyMongo(app) if config.DATA: with app.app_context(): db_reset(mongo, config.DATA) # Store the Mongo database object in the Flask globals so that it can # be accessed when needed. @app.before_request def before_request(): g.mongo = mongo # This Jinja2 template must be defined here, on the app, rather than # in views.py on the blueprint. @app.template_filter('start_token') def start_token(name): '''This routine returns the substring of the given name up to but not including the first slash. If there is no slash, it returns the full name. It is used in the templates to find either the page name or category.''' if (name.find('/') == -1): return name else: return name[:name.find('/')] return app
def create_app(config_class=Config):
    """Application factory.

    :param config_class: configuration object/class to load settings from.
    :return: the configured Flask application.
    """
    app = Flask(__name__)
    # Fix: previously ``Config`` was hard-coded here, silently ignoring the
    # ``config_class`` argument callers passed in.
    app.config.from_object(config_class)
    with app.app_context():
        db.init_app(app)
        login_manager.init_app(app)
        sess.init_app(app)
        # Imported here to register models/routes after extensions exist
        # and to avoid circular imports.
        from . import models
        from .routes import main
        app.register_blueprint(main)
    return app
def __new__(cls, *args, **kwargs):
    """Lazily build and cache the singleton FlaskInjector instance."""
    if cls._instance is None:
        flask_app = Flask(__name__)
        flask_app.register_blueprint(PeopleBluePrintFactory.create())
        injector = FlaskInjector(
            app=flask_app,
            modules=[DatabaseModule(), ],
        )
        flask_app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
        flask_app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:////tmp/production.db'
        db.init_app(flask_app)
        with flask_app.app_context():
            db.create_all()
        cls._instance = injector
    return cls._instance
def parse_node(self, response, node):
    """Parse one CDS MARCXML node into a hep-format ParsedItem.

    Raises RuntimeError when the bibrecord cannot be created.
    """
    node.remove_namespaces()
    cds_bibrec, ok, errs = create_bibrec(
        node.xpath('.//record').extract()[0])
    if not ok:
        # Fix: RuntimeError does not apply %-style args like a logger does;
        # interpolate explicitly so the message actually shows node/errs.
        raise RuntimeError("Cannot parse record %s: %s" % (node, errs))
    self.logger.info("Here's the record: %s", cds_bibrec)
    inspire_bibrec = CDS2Inspire(cds_bibrec).get_record()
    marcxml_record = record_xml_output(inspire_bibrec)
    record = create_record(marcxml_record)
    # dojson's hep rules need an active Flask application context.
    app = Flask('hepcrawl')
    app.config.update(self.settings.getdict('MARC_TO_HEP_SETTINGS', {}))
    with app.app_context():
        json_record = hep.do(record)
        base_uri = self.settings['SCHEMA_BASE_URI']
        json_record['$schema'] = base_uri + 'hep.json'
        parsed_item = ParsedItem(
            record=json_record,
            record_format='hep',
        )
        return parsed_item
def create_app(config_name):
    """Application factory.

    Exits the process with a colored console message when *config_name* is
    unknown or required config keys are missing.
    """
    RED = '\033[31m'
    RESET = '\033[0m'
    app = Flask(__name__, instance_relative_config=True)
    try:
        app.config.from_object(app_config[config_name])
    except KeyError:
        # Fix: print statements were Python-2 only; print() with a single
        # argument behaves identically on Python 2 and 3.
        print('{}FLASK_CONFIG not found in environment variable. Define as dev, stage or prod{}'.format(RED, RESET))
        sys.exit()
    app.config.from_pyfile('app_config.py')
    try:
        with app.app_context():
            for module in initiables:
                module.init_app(app)
        return app
    except KeyError as kerr:
        # Fix: KeyError.message is Python-2 only; str(kerr) is portable
        # (it renders the missing key, quoted).
        print('{}{} not found in config{}'.format(RED, str(kerr), RESET))
        sys.exit()
def run():
    """ daemon run function.

    This function should be called to start the system. Builds the global
    Flask app from the INI config, then on Linux daemonizes (double-fork via
    python-daemon, preserving log file descriptors) before entering the
    sensor-reading loop; on other systems it runs the loop in-process.
    """
    instance_path = ini_config.get("Flask", "INSTANCE_PATH")

    # app: Flask application object
    logging.debug("initializing the Flask app")
    global globalFlaskApp
    globalFlaskApp = Flask(__name__,
                           instance_path=instance_path,
                           instance_relative_config=True)

    is_debug = ini_config.getboolean("Flask", "DEBUG")
    is_testing = ini_config.getboolean("Flask", "TESTING")
    is_json_sort_keys = ini_config.getboolean("Flask", "JSON_SORT_KEYS")
    max_content_length = ini_config.getint("Flask", "MAX_CONTENT_LENGTH")

    globalFlaskApp.config.update(DEBUG=is_debug,
                                 TESTING=is_testing,
                                 JSON_SORT_KEYS=is_json_sort_keys,
                                 MAX_CONTENT_LENGTH=max_content_length)

    with globalFlaskApp.app_context():
        logging.info("Starting application ...")

        from rgapps.utils.utility import get_log_file_handles
        # Collect the log file descriptors so the daemon fork keeps them open.
        logger_fds = get_log_file_handles(logging.getLogger())
        logging.debug("Logger file handles fileno [{0}]"
                      .format(logger_fds))

        system = platform.system()

        if system == "Linux":
            logging.info("Server running on Linux.")

            pid_file = ini_config.get("Sensor", "SENSOR_PID_FILE")
            working_dir = ini_config.get("Logging", "WORKING_DIR")

            logging.debug("Instantiating daemon with pid_file [{0}] "
                          "and working_dir [{1}]"
                          .format(pid_file, working_dir))

            import daemon.pidfile

            daemon_context = daemon.DaemonContext(
                working_directory=working_dir,
                umask=0o002,
                pidfile=daemon.pidfile.PIDLockFile(pid_file))

            logging.debug("Setting up daemon signal map")
            # SIGTERM triggers graceful cleanup instead of a hard kill.
            daemon_context.signal_map = {
                signal.SIGTERM: program_cleanup
            }
            logging.debug("daemon signal map has been setup")

            if (logger_fds):
                logging.debug("setting files_preserve for the log file "
                              "descriptor [{0}]"
                              .format(logger_fds))
                daemon_context.files_preserve = logger_fds

            # open() performs the actual daemonization (fork/setsid/chdir).
            logging.debug("Starting daemon by opening its context.")
            daemon_context.open()

            logging.info("Calling read_store_readings....")
            # Blocking loop: read sensors and store readings.
            read_store_readings()

            logging.debug("Closing the daemon context.")
            daemon_context.close()
        else:
            # Non-Linux (e.g. Windows): run in the foreground, no daemon.
            logging.info("Server running on Windows system ...")
            read_store_readings()

    return
from flask.app import Flask
from flask_bcrypt import Bcrypt

from models import create_db
from resources import init_resources

# The app and the bcrypt handle are module-level singletons;
# create_app() finishes their configuration.
app = Flask(__name__)
app.app_context().push()
flask_bcrypt = Bcrypt()


def create_app() -> Flask:
    """Configure the module-level app: database, bcrypt and resources."""
    app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///test.sqlite3'
    flask_bcrypt.init_app(app)
    create_db(app)
    init_resources(app)
    return app
class SimpleFlaskAppTest(unittest.TestCase):
    """End-to-end tests for serializers, resources and routers running on a
    throw-away Flask + MongoEngine app (database "test" is dropped per test)."""

    def setUp(self):
        # Fresh app/client and a clean Mongo database for every test.
        self.app = Flask(__name__)
        self.client = self.app.test_client()
        self.db = MongoEngine(self.app)
        with self.app.app_context():
            self.db.connection.drop_database("test")
            # self.db.connection

        class TestCol(db.Document):
            value = db.StringField()

            def __unicode__(self):
                return "TestCol(value={})".format(self.value)

        TestCol.objects.delete()
        TestCol.objects.create(value="1")
        TestCol.objects.create(value="2")
        self.TestCol = TestCol

    def _parse(self, resp):
        # resp: raw bytes from the test client -> decoded JSON object.
        resp = resp.decode("utf-8")
        return json.loads(resp)

    def test_validation_mongoengine_will_work_with_model_serializer(self):
        class Doc(db.Document):
            value = db.StringField(validation=RegexpValidator(
                r"\d+", message="Bad value").for_mongoengine())

        Doc.drop_collection()

        class Serializer(ModelSerializer):
            class Meta:
                model = Doc

        Doc.objects.create(value="123")
        s = Serializer(data={"value": "asd"})
        self.assertEqual(s.validate(), False)
        self.assertEqual(s.errors, {"value": ["Bad value"]})

    def test_resource_decorator(self):
        class S(BaseSerializer):
            field = fields.StringField(required=True)

        @self.app.route("/test", methods=["POST"])
        @validate(S)
        def resource(cleaned_data):
            return "OK"

        # Empty JSON body must be rejected with a per-field error message.
        resp = self.client.post("/test", data=json.dumps({}),
                                headers={"Content-Type": "application/json"})
        self.assertEqual(resp.status_code, 400)
        self.assertEqual(
            json.loads(resp.data.decode("utf-8")),
            {'field': ['Field is required']}
        )

    def testSimpleResourceAndRouter(self):
        router = DefaultRouter(self.app)

        class Resource(BaseResource):
            def get(self, request):
                return "GET"

            def post(self, request):
                return "POST"

            def put(self, request):
                return "PUT"

            def patch(self, request):
                return "PATCH"

            def delete(self, request):
                return "DELETE"

            @list_route(methods=["GET", "POST"])
            def listroute(self, request):
                return "LIST"

            @detail_route(methods=["GET", "POST"])
            def detailroute(self, request, pk):
                return "detail"

        self.assertSetEqual(
            set(Resource.get_allowed_methods()),
            {"get", "post", "put", "patch", "delete"}
        )
        router.register("/test", Resource, "test")
        # Each verb dispatches to the matching handler.
        for method in ["get", "post", "put", "patch", "delete"]:
            resp = getattr(self.client, method)("/test")
            self.assertEqual(resp.data.decode("utf-8"), method.upper())
        for method in ["GET", "POST"]:
            resp = getattr(self.client, method.lower())("/test/listroute")
            self.assertEqual(resp.status_code, 200)
            self.assertEqual(resp.data.decode("utf-8"), "LIST")
        for method in ["GET", "POST"]:
            resp = getattr(self.client, method.lower())("/test/detailroute/1")
            self.assertEqual(resp.status_code, 200)
            self.assertEqual(resp.data.decode("utf-8"), "detail")
        # Detail route without a pk must 404.
        resp = self.client.get("/test/detailroute")
        self.assertEqual(resp.status_code, 404)

    def testRoutingWithBluePrint(self):
        bp = Blueprint("test", __name__)
        router = DefaultRouter(bp)

        class Res(BaseResource):
            def get(self, request):
                return "GET"

        router.register("/blabla", Res, "blabla")
        self.app.register_blueprint(bp, url_prefix="/test")
        with self.app.test_request_context():
            self.assertEqual(url_for("test.blabla"), "/test/blabla")

    @pytest.mark.testModelResource123
    def testModelResource(self):
        router = DefaultRouter(self.app)

        class Base(db.Document):
            title = db.StringField()

        class ED(db.EmbeddedDocument):
            value = db.StringField()

        class Model(db.Document):
            base = db.ReferenceField(Base)
            f1 = db.StringField()
            f2 = db.BooleanField()
            f3 = db.StringField()
            embedded = db.EmbeddedDocumentField(ED)
            listf = db.EmbeddedDocumentListField(ED)
            dictf = db.DictField()

        Model.objects.delete()
        ins = Model.objects.create(
            base=Base.objects.create(title="1"),
            f1="1",
            f2=True,
            f3="1",
            embedded={"value": "123"},
            listf=[{"value": "234"}],
            dictf={"key": "value"}
        )
        Model.objects.create(
            base=Base.objects.create(title="2"),
            f1="2",
            f2=True,
            f3="2",
            embedded={"value": "123"},
            listf=[{"value": "234"}]
        )

        class S(ModelSerializer):
            title = fields.ForeignKeyField("base__title")

            class Meta:
                model = Model
                fk_fields = ("base__title", )

        class ModelRes(ModelResource):
            serializer_class = S
            queryset = Model.objects.all()
            pagination_class = DefaultPagination

        router.register("/test", ModelRes, "modelres")
        resp = self.client.get("/test")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        self.assertEqual(len(data["results"]), 2)
        item = data["results"][0]
        self.assertEqual(item["dictf"], {"key": "value"})
        self.assertEqual(item["title"], "1")
        self.assertEqual(item["base__title"], "1")
        # get one object
        resp = self.client.get("/test/{}".format(ins.id))
        self.assertEqual(resp.status_code, 200)
        pprint(self._parse(resp.data))
        #test pagination
        for i in range(10):
            Model.objects.create(
                base=Base.objects.create(title="1"),
                f1="1",
                f2=True,
                f3="2"
            )
        self.assertEqual(Model.objects.count(), 12)
        resp = self.client.get("/test?page=1")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        results = data["results"]
        self.assertEqual(results[0]["embedded"], {"value": "123"})
        self.assertEqual(results[0]["listf"], [{"value": "234"}])
        self.assertEqual(len(data["results"]), 10)
        resp = self.client.get("/test?page=2")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        self.assertEqual(len(data["results"]), 2)
        resp = self.client.get("/test?page=2&page_size=5")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        self.assertEqual(len(data["results"]), 5)
        resp = self.client.get("/test?page=3&page_size=5")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        self.assertEqual(len(data["results"]), 2)
        #test put
        resp = self.client.put("/test/{}".format(ins.id), data=json.dumps({
            "f3": "OLALA"
        }), headers={"Content-Type": "application/json"})
        self.assertEqual(resp.status_code, 200, resp.data)
        data = self._parse(resp.data)
        self.assertEqual(data["f1"], "1")
        self.assertEqual(data["f2"], True)
        self.assertEqual(data["f3"], "OLALA")

    def testMongoEngineForeignKeyField(self):
        self.assertEqual(self.TestCol.objects.count(), 2)

        class Serializer(BaseSerializer):
            fk = fields.MongoEngineIdField(self.TestCol, required=True)

        # A bogus id fails validation; a real id resolves to the document.
        v = Serializer({"fk": "123"})
        self.assertEqual(v.validate(), False)
        self.assertEqual(v.errors, {'fk': ['Incorrect id: 123']})
        v = Serializer({"fk": str(self.TestCol.objects.first().id)})
        v.validate()
        self.assertEqual(v.errors, {})
        self.assertEqual(v.cleaned_data["fk"], self.TestCol.objects.first())

        class R(BaseResource):
            def post(self, request):
                errors, data = self.validate_request(Serializer)
                if errors:
                    return errors
                return "OK"

        self.app.add_url_rule("/api", view_func=R.as_view("test2"),
                              methods=["GET", "POST"])
        resp = self.client.post("/api", data=json.dumps({}),
                                headers={"Content-Type": "application/json"})
        self.assertEqual(resp.status_code, 400)
        data = json.loads(resp.data.decode("utf-8"))
        self.assertEqual(data["fk"], ['Field is required'])

    def testSerialization(self):
        class Col(db.Document):
            value = db.StringField()
            created = db.DateTimeField(default=datetime.datetime.now)

        Col.objects.delete()
        Col.objects.create(value="1")
        Col.objects.create(value="2")

        class S(BaseSerializer):
            value = fields.StringField()
            created = fields.DateTimeField(read_only=True)

        data = S(Col.objects.all()).to_python()
        self.assertEqual(len(data), 2)
        self.assertEqual(
            list(map(lambda i: i["value"], data)),
            ["1", "2"]
        )
        #test can't set read only field
        ser = S({"value": "1", "created": "2016-01-01 00:00:00"})
        ser.validate()
        self.assertTrue("created" not in ser.cleaned_data)

    def testModelSerialization(self):
        class DeepInner(db.EmbeddedDocument):
            value = db.StringField()

        class Inner(db.EmbeddedDocument):
            value = db.StringField()
            deep = db.EmbeddedDocumentField(DeepInner)

        class Col(db.Document):
            value = db.StringField()
            excluded_field = db.StringField(default="excluded")
            created = db.DateTimeField(default=datetime.datetime.now)
            inner = db.EmbeddedDocumentField(Inner)

        Col.objects.delete()
        Col.objects.create(value="1",
                           inner={"value": "inner1", "deep": {"value": "123"}})
        Col.objects.create(value="2", inner={"value": "inner2"})

        class Serializer(ModelSerializer):
            method_field = fields.MethodField("test")
            renamed = fields.ForeignKeyField(
                document_fieldname="inner__deep__value")

            def test(self, doc):
                return doc.value

            class Meta:
                model = Col
                fields = ("value", "created", "method_field")
                fk_fields = ("inner__value", "inner__deep__value")

        data = Serializer(Col.objects.all()).to_python()
        for item in data:
            self.assertTrue("value" in item)
            self.assertEqual(item["value"], item["method_field"])
            self.assertTrue(type(item["created"]), datetime.datetime)
            self.assertEqual(item["renamed"], item["inner__deep__value"])
class BaseTestCase(TestCase):
    """Shared test fixture: builds a fully wired Flask app (blueprints,
    services, login system, DB, plugins) and keeps an application context
    open for the whole test."""

    def setUp(self):
        self.app = Flask(__name__)
        self.test_client = self.app.test_client()
        self.init_logging()
        self.init_validator_context()
        self.config = TEST_CONFIG
        self.auth_cookie = None
        load_filters(self.app.jinja_env, self.config)
        # Context stays open for the whole test; closed in tearDown.
        self.app_context = self.app.app_context()
        self.app_context.__enter__()
        set_template_loader(self.app.jinja_env)
        init_configuration(self.app, self.config)
        init_blueprints(self.app)
        init_services(self.app)
        init_login_system(self.app)
        init_db(self.app)
        init_plugins()
        # Test doubles from the celery config collect outgoing mail/SMS
        # into plain lists for assertions.
        self.mailer = celery.conf['MAILER']
        self.mailer.mails = []
        self.sms_sender = celery.conf['SMS_SENDER']
        self.sms_sender.sms = []
        self.user = None
        self.user_profile = None
        UserManager.init(self.config, self.app.logger)
        sql_db.init_app(self.app)
        sql_db.create_all()
        # Empty every table in FK-safe (reverse dependency) order.
        for table in reversed(sql_db.metadata.sorted_tables):
            sql_db.engine.execute(table.delete())

        @self.app.errorhandler(413)
        def catcher(error):
            # Rewrap payload-too-large as a JSON 400 with CORS headers.
            data_json = json.dumps({"error": {"code": errors.FileToLarge.ERROR_CODE,
                                              "message": errors.FileToLarge.ERROR_MESSAGE}})
            result = Response(data_json, mimetype='application/json', status=400)
            result.headers.add('Access-Control-Allow-Credentials', "true")
            result.headers.add('Access-Control-Allow-Origin',
                               "http://%s" % self.config['site_domain'])
            return result

    def tearDown(self):
        sql_db.session.close()
        #sql_db.drop_all()
        for table in reversed(sql_db.metadata.sorted_tables):
            sql_db.engine.execute(table.delete())
        # noinspection PyUnresolvedReferences
        self.app.model_cache_context.clear()
        self.app_context.__exit__(None, None, None)

    def get_test_resource_name(self, name):
        # Absolute path of a fixture file under test_data/.
        return os.path.join(CURRENT_DIR, 'test_data', name)

    def init_logging(self):
        consoleHandler = logging.StreamHandler()
        consoleHandler.setFormatter(
            logging.Formatter('%(asctime)s %(levelname)s: %(message)s [in %(pathname)s:%(lineno)d]'))
        consoleHandler.setLevel(logging.DEBUG)
        self.app.logger.addHandler(consoleHandler)
        self.app.logger.setLevel(logging.DEBUG)

    def init_validator_context(self):
        # Per-app contexts for validation, rendering and model caching.
        self.app.validator_context = ValidatorContext()
        self.app.rendering_context = RenderingContext()
        self.app.model_cache_context = ModelCacheContext()
class BaseTestCase(TestCase):
    """Shared test fixture: builds a fully wired Flask app (blueprints,
    services, login system, DB, plugins) and keeps an application context
    open for the whole test."""

    def setUp(self):
        self.app = Flask(__name__)
        self.test_client = self.app.test_client()
        self.init_logging()
        self.init_validator_context()
        self.config = TEST_CONFIG
        self.auth_cookie = None
        load_filters(self.app.jinja_env, self.config)
        # Context stays open for the whole test; closed in tearDown.
        self.app_context = self.app.app_context()
        self.app_context.__enter__()
        set_template_loader(self.app.jinja_env)
        init_configuration(self.app, self.config)
        init_blueprints(self.app)
        init_services(self.app)
        init_login_system(self.app)
        init_db(self.app)
        init_plugins()
        # Test doubles from the celery config collect outgoing mail/SMS
        # into plain lists for assertions.
        self.mailer = celery.conf['MAILER']
        self.mailer.mails = []
        self.sms_sender = celery.conf['SMS_SENDER']
        self.sms_sender.sms = []
        self.user = None
        self.user_profile = None
        UserManager.init(self.config, self.app.logger)
        sql_db.init_app(self.app)
        sql_db.create_all()
        # Empty every table in FK-safe (reverse dependency) order.
        for table in reversed(sql_db.metadata.sorted_tables):
            sql_db.engine.execute(table.delete())

        @self.app.errorhandler(413)
        def catcher(error):
            # Rewrap payload-too-large as a JSON 400 with CORS headers.
            data_json = json.dumps({
                "error": {
                    "code": errors.FileToLarge.ERROR_CODE,
                    "message": errors.FileToLarge.ERROR_MESSAGE
                }
            })
            result = Response(data_json,
                              mimetype='application/json',
                              status=400)
            result.headers.add('Access-Control-Allow-Credentials', "true")
            result.headers.add('Access-Control-Allow-Origin',
                               "http://%s" % self.config['site_domain'])
            return result

    def tearDown(self):
        sql_db.session.close()
        #sql_db.drop_all()
        for table in reversed(sql_db.metadata.sorted_tables):
            sql_db.engine.execute(table.delete())
        # noinspection PyUnresolvedReferences
        self.app.model_cache_context.clear()
        self.app_context.__exit__(None, None, None)

    def get_test_resource_name(self, name):
        # Absolute path of a fixture file under test_data/.
        return os.path.join(CURRENT_DIR, 'test_data', name)

    def init_logging(self):
        consoleHandler = logging.StreamHandler()
        consoleHandler.setFormatter(
            logging.Formatter(
                '%(asctime)s %(levelname)s: %(message)s [in %(pathname)s:%(lineno)d]'
            ))
        consoleHandler.setLevel(logging.DEBUG)
        self.app.logger.addHandler(consoleHandler)
        self.app.logger.setLevel(logging.DEBUG)

    def init_validator_context(self):
        # Per-app contexts for validation, rendering and model caching.
        self.app.validator_context = ValidatorContext()
        self.app.rendering_context = RenderingContext()
        self.app.model_cache_context = ModelCacheContext()
def app():
    """Yield a bare Flask app inside an active application context."""
    flask_app = Flask(__name__)
    with flask_app.app_context():
        yield flask_app
from flask.app import Flask
from flask_cors import CORS

from blog.models.exts import db
from blog.models.exts import bcrypt
from blog.models.modetool import creat_db
from blog.urls.main import init_url

# Swap to DevelopmentConfig for local work.
config = 'conf.flask.config.ProductionConfig'
#config = 'conf.flask.config.DevelopmentConfig'

app = Flask(__name__,
            static_folder="./web/static",
            template_folder="./web")
CORS(app)
app.config.from_object(config)

db.init_app(app)
bcrypt.init_app(app)

# Table creation needs an active application context.
with app.app_context():
    creat_db()

init_url(app)

if __name__ == '__main__':
    app.run(port=8000)
class SimpleFlaskAppTest(unittest.TestCase):
    """Integration tests for the REST-framework helpers (serializers,
    resources, routers, pagination) against a MongoEngine-backed Flask app.

    Each test gets a fresh Flask app, a test client, and a dropped
    ``test`` database populated with two ``TestCol`` documents.
    """

    def setUp(self):
        self.app = Flask(__name__)
        self.client = self.app.test_client()
        self.db = MongoEngine(self.app)
        with self.app.app_context():
            # Start every test from an empty database.
            self.db.connection.drop_database("test")
            # self.db.connection

        class TestCol(db.Document):
            value = db.StringField()

            def __unicode__(self):
                return "TestCol(value={})".format(self.value)

        TestCol.objects.delete()
        TestCol.objects.create(value="1")
        TestCol.objects.create(value="2")
        self.TestCol = TestCol

    def _parse(self, resp):
        """Decode a raw response body and parse it as JSON."""
        resp = resp.decode("utf-8")
        return json.loads(resp)

    def test_validation_mongoengine_will_work_with_model_serializer(self):
        """RegexpValidator attached to a document field surfaces through
        ModelSerializer.errors."""
        class Doc(db.Document):
            value = db.StringField(validation=RegexpValidator(
                r"\d+", message="Bad value").for_mongoengine())

        Doc.drop_collection()

        class Serializer(ModelSerializer):
            class Meta:
                model = Doc

        Doc.objects.create(value="123")
        s = Serializer(data={"value": "asd"})
        self.assertEqual(s.validate(), False)
        self.assertEqual(s.errors, {"value": ["Bad value"]})

    def test_resource_decorator(self):
        """@validate rejects a payload missing a required field with 400."""
        class S(BaseSerializer):
            field = fields.StringField(required=True)

        @self.app.route("/test", methods=["POST"])
        @validate(S)
        def resource(cleaned_data):
            return "OK"

        resp = self.client.post("/test", data=json.dumps({}),
                                headers={"Content-Type": "application/json"})
        self.assertEqual(resp.status_code, 400)
        self.assertEqual(json.loads(resp.data.decode("utf-8")),
                         {'field': ['Field is required']})

    def testSimpleResourceAndRouter(self):
        """Router maps all HTTP verbs plus list/detail extra routes."""
        router = DefaultRouter(self.app)

        class Resource(BaseResource):
            def get(self, request):
                return "GET"

            def post(self, request):
                return "POST"

            def put(self, request):
                return "PUT"

            def patch(self, request):
                return "PATCH"

            def delete(self, request):
                return "DELETE"

            @list_route(methods=["GET", "POST"])
            def listroute(self, request):
                return "LIST"

            @detail_route(methods=["GET", "POST"])
            def detailroute(self, request, pk):
                return "detail"

        self.assertSetEqual(set(Resource.get_allowed_methods()),
                            {"get", "post", "put", "patch", "delete"})
        router.register("/test", Resource, "test")
        for method in ["get", "post", "put", "patch", "delete"]:
            resp = getattr(self.client, method)("/test")
            self.assertEqual(resp.data.decode("utf-8"), method.upper())
        for method in ["GET", "POST"]:
            resp = getattr(self.client, method.lower())("/test/listroute")
            self.assertEqual(resp.status_code, 200)
            self.assertEqual(resp.data.decode("utf-8"), "LIST")
        for method in ["GET", "POST"]:
            resp = getattr(self.client, method.lower())("/test/detailroute/1")
            self.assertEqual(resp.status_code, 200)
            self.assertEqual(resp.data.decode("utf-8"), "detail")
        # Detail route without a pk segment must not resolve.
        resp = self.client.get("/test/detailroute")
        self.assertEqual(resp.status_code, 404)

    def testRoutingWithBluePrint(self):
        """Router also registers onto a Blueprint, honoring url_prefix."""
        bp = Blueprint("test", __name__)
        router = DefaultRouter(bp)

        class Res(BaseResource):
            def get(self, request):
                return "GET"

        router.register("/blabla", Res, "blabla")
        self.app.register_blueprint(bp, url_prefix="/test")
        with self.app.test_request_context():
            self.assertEqual(url_for("test.blabla"), "/test/blabla")

    @pytest.mark.testModelResource123
    def testModelResource(self):
        """End-to-end ModelResource: listing, FK fields, pagination, PUT."""
        router = DefaultRouter(self.app)

        class Base(db.Document):
            title = db.StringField()

        class ED(db.EmbeddedDocument):
            value = db.StringField()

        class Model(db.Document):
            base = db.ReferenceField(Base)
            f1 = db.StringField()
            f2 = db.BooleanField()
            f3 = db.StringField()
            embedded = db.EmbeddedDocumentField(ED)
            listf = db.EmbeddedDocumentListField(ED)
            dictf = db.DictField()

        Model.objects.delete()
        ins = Model.objects.create(base=Base.objects.create(title="1"),
                                   f1="1",
                                   f2=True,
                                   f3="1",
                                   embedded={"value": "123"},
                                   listf=[{"value": "234"}],
                                   dictf={"key": "value"})
        Model.objects.create(base=Base.objects.create(title="2"),
                             f1="2",
                             f2=True,
                             f3="2",
                             embedded={"value": "123"},
                             listf=[{"value": "234"}])

        class S(ModelSerializer):
            title = fields.ForeignKeyField("base__title")

            class Meta:
                model = Model
                fk_fields = ("base__title", )

        class ModelRes(ModelResource):
            serializer_class = S
            queryset = Model.objects.all()
            pagination_class = DefaultPagination

        router.register("/test", ModelRes, "modelres")
        resp = self.client.get("/test")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        self.assertEqual(len(data["results"]), 2)
        item = data["results"][0]
        self.assertEqual(item["dictf"], {"key": "value"})
        self.assertEqual(item["title"], "1")
        self.assertEqual(item["base__title"], "1")
        # get one object
        resp = self.client.get("/test/{}".format(ins.id))
        self.assertEqual(resp.status_code, 200)
        pprint(self._parse(resp.data))
        #test pagination
        for i in range(10):
            Model.objects.create(base=Base.objects.create(title="1"),
                                 f1="1",
                                 f2=True,
                                 f3="2")
        self.assertEqual(Model.objects.count(), 12)
        resp = self.client.get("/test?page=1")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        results = data["results"]
        self.assertEqual(results[0]["embedded"], {"value": "123"})
        self.assertEqual(results[0]["listf"], [{"value": "234"}])
        self.assertEqual(len(data["results"]), 10)
        resp = self.client.get("/test?page=2")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        self.assertEqual(len(data["results"]), 2)
        resp = self.client.get("/test?page=2&page_size=5")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        self.assertEqual(len(data["results"]), 5)
        resp = self.client.get("/test?page=3&page_size=5")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        self.assertEqual(len(data["results"]), 2)
        #test put
        resp = self.client.put("/test/{}".format(ins.id),
                               data=json.dumps({"f3": "OLALA"}),
                               headers={"Content-Type": "application/json"})
        self.assertEqual(resp.status_code, 200, resp.data)
        data = self._parse(resp.data)
        self.assertEqual(data["f1"], "1")
        self.assertEqual(data["f2"], True)
        self.assertEqual(data["f3"], "OLALA")

    def testMongoEngineForeignKeyField(self):
        """MongoEngineIdField rejects bad ids and resolves valid ones."""
        self.assertEqual(self.TestCol.objects.count(), 2)

        class Serializer(BaseSerializer):
            fk = fields.MongoEngineIdField(self.TestCol, required=True)

        v = Serializer({"fk": "123"})
        self.assertEqual(v.validate(), False)
        self.assertEqual(v.errors, {'fk': ['Incorrect id: 123']})
        v = Serializer({"fk": str(self.TestCol.objects.first().id)})
        v.validate()
        self.assertEqual(v.errors, {})
        # A valid id dereferences to the actual document.
        self.assertEqual(v.cleaned_data["fk"], self.TestCol.objects.first())

        class R(BaseResource):
            def post(self, request):
                errors, data = self.validate_request(Serializer)
                if errors:
                    return errors
                return "OK"

        self.app.add_url_rule("/api",
                              view_func=R.as_view("test2"),
                              methods=["GET", "POST"])
        resp = self.client.post("/api",
                                data=json.dumps({}),
                                headers={"Content-Type": "application/json"})
        self.assertEqual(resp.status_code, 400)
        data = json.loads(resp.data.decode("utf-8"))
        self.assertEqual(data["fk"], ['Field is required'])

    def testSerialization(self):
        """BaseSerializer serializes querysets and ignores read-only input."""
        class Col(db.Document):
            value = db.StringField()
            created = db.DateTimeField(default=datetime.datetime.now)

        Col.objects.delete()
        Col.objects.create(value="1")
        Col.objects.create(value="2")

        class S(BaseSerializer):
            value = fields.StringField()
            created = fields.DateTimeField(read_only=True)

        data = S(Col.objects.all()).to_python()
        self.assertEqual(len(data), 2)
        self.assertEqual(list(map(lambda i: i["value"], data)), ["1", "2"])
        #test can't set read only field
        ser = S({"value": "1", "created": "2016-01-01 00:00:00"})
        ser.validate()
        self.assertTrue("created" not in ser.cleaned_data)

    def testModelSerialization(self):
        """ModelSerializer handles method fields, renamed FK fields and
        nested embedded-document lookups."""
        class DeepInner(db.EmbeddedDocument):
            value = db.StringField()

        class Inner(db.EmbeddedDocument):
            value = db.StringField()
            deep = db.EmbeddedDocumentField(DeepInner)

        class Col(db.Document):
            value = db.StringField()
            excluded_field = db.StringField(default="excluded")
            created = db.DateTimeField(default=datetime.datetime.now)
            inner = db.EmbeddedDocumentField(Inner)

        Col.objects.delete()
        Col.objects.create(value="1",
                           inner={
                               "value": "inner1",
                               "deep": {"value": "123"}
                           })
        Col.objects.create(value="2", inner={"value": "inner2"})

        class Serializer(ModelSerializer):
            method_field = fields.MethodField("test")
            renamed = fields.ForeignKeyField(
                document_fieldname="inner__deep__value")

            def test(self, doc):
                return doc.value

            class Meta:
                model = Col
                fields = ("value", "created", "method_field")
                fk_fields = ("inner__value", "inner__deep__value")

        data = Serializer(Col.objects.all()).to_python()
        for item in data:
            self.assertTrue("value" in item)
            self.assertEqual(item["value"], item["method_field"])
            # FIX: was assertTrue(type(...), datetime.datetime), which passed
            # the class as the *msg* argument and therefore always succeeded;
            # assertIsInstance actually checks the type.
            self.assertIsInstance(item["created"], datetime.datetime)
            self.assertEqual(item["renamed"], item["inner__deep__value"])
if os.environ['SERVER_SOFTWARE'].startswith('Dev'): return constants.ENV_LOCAL elif os.environ['SERVER_SOFTWARE'].startswith('Google App Engine/'): #For considering an environment staging we assume the version id # contains -staging and the URL current_version_id = str(os.environ['CURRENT_VERSION_ID']) if ( 'CURRENT_VERSION_ID') in os.environ else '' if '-staging' in current_version_id: return constants.ENV_STAGING #If not local or staging then is production TODO: really? return constants.ENV_PRODUCTION return constants.ENV_LOCAL flask_app = Flask(__name__) flask_app.json_encoder = CustomJSONEncoder with flask_app.app_context(): environment = get_environment() #Load settings from the corresponding class if environment == constants.ENV_PRODUCTION: flask_app.config.from_object(ProductionConfig) else: flask_app.config.from_object(TestingConfig) #If debug then enable if flask_app.config['DEBUG']: flask_app.debug = True app = DebuggedApplication(flask_app, evalex=True) app = flask_app from google.appengine.ext.deferred import application as deferred_app import admin_views import views