Beispiel #1
0
def register_extensions(app):
    """Attach every Flask extension used by the project to *app*.

    The database receives an explicit back-reference to the app before it is
    initialised; then a new CSRFProtect instance is bound, followed by the
    cache, the login manager and the Celery integration (in that order).
    """
    db.app = app
    db.init_app(app)

    csrf_guard = CSRFProtect()
    csrf_guard.init_app(app)

    # The remaining extensions share the plain init_app protocol.
    for extension in (cache, login_manager, celery_app):
        extension.init_app(app)
Beispiel #2
0
def create_app(config='config.Production', instance=True):
    """Build and configure the Flask application.

    Args:
        config: Import path of the configuration object loaded first.
        instance: Passed to Flask as ``instance_relative_config``; when true,
            ``config.py`` below is looked up in the instance folder.

    Returns:
        The configured Flask application.
    """
    app = Flask(__name__, instance_relative_config=instance)

    # Layered configuration: base object, then an optional config.py, then an
    # optional file named by the FLASKR_SETTINGS environment variable.
    app.config.from_object(config)
    app.config.from_pyfile('config.py', silent=True)
    app.config.from_envvar('FLASKR_SETTINGS', silent=True)
    app.debug = app.config['DEBUG']
    csrf = CSRFProtect()
    csrf.init_app(app)
    register_api(app)

    db.init_app(app)
    user_datastore = get_datastore(db)
    app.security = Security(app, user_datastore)

    # NOTE(review): before_first_request was removed in Flask 2.3; this code
    # requires an older Flask.
    @app.before_first_request
    def init_database():
        """Seed the database if reading users fails; ensure an SSH key exists."""
        try:
            app.logger.info('Getting users...')
            get_users()
            app.logger.info('Success.')
        except Exception:
            # Any failure to read users is treated as "database missing".
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit
            # still propagate; the traceback is now logged too.
            app.logger.exception('Fail. We are gonna create the database')
            first_data(app, user_datastore)

        # expanduser falls back to the password database when $HOME is unset,
        # whereas os.getenv("HOME") + '...' raised TypeError in that case.
        if not os.path.isfile(os.path.expanduser('~/.ssh/keys/id_rsa')):
            ssh.generate_key()

    @app.route('/', methods=['GET'])
    @login_required
    @roles_required('user')
    def index():
        """Render the dashboard for logged-in users holding the 'user' role."""
        return render_template('dashboard.html')

    @app.before_request
    def log_request():
        """Log every incoming request at DEBUG level."""
        app.logger.debug(request)

    return app
Beispiel #3
0
def create_app():
    """Create the docassemble web application.

    Configures SQLAlchemy from docassemble's connection settings, sets the
    secret key from the ``secretkey`` config entry, attaches CSRF protection
    and Babel, and wraps the WSGI app in ProxyFix when the
    ``behind https load balancer`` option is enabled.

    Returns:
        tuple: ``(app, csrf, babel)`` -- the Flask app and the two extension
        instances bound to it.
    """
    app = Flask(__name__)
    from docassemble.base.config import daconfig
    import docassemble.webapp.database
    import docassemble.webapp.db_object
    app.config['SQLALCHEMY_DATABASE_URI'] = (
        docassemble.webapp.database.alchemy_connection_string())
    # SECURITY NOTE(review): the fallback secret is a constant shipped in the
    # source; every deployment without 'secretkey' configured shares it, so
    # session cookies become forgeable. Consider refusing to start instead.
    app.secret_key = daconfig.get('secretkey', '38ihfiFehfoU34mcq_4clirglw3g4o87')
    # Previously assigned twice with the same value; once is enough.
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
    db = docassemble.webapp.db_object.init_flask()
    db.init_app(app)
    csrf = CSRFProtect()
    csrf.init_app(app)
    babel = Babel()
    babel.init_app(app)
    if daconfig.get('behind https load balancer', False):
        # Newer ProxyFix (>= 0.15) must be told explicitly which
        # X-Forwarded-* headers to trust.
        if proxyfix_version >= 15:
            app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1)
        else:
            app.wsgi_app = ProxyFix(app.wsgi_app)
    return app, csrf, babel
Beispiel #4
0
import random
import string
from flask_wtf.csrf import CSRFProtect

# CSRF extension created unbound here and attached to the app below via
# init_app().
csrf = CSRFProtect()

db = SQLAlchemy()
# After defining `db`, import auth models due to
# circular dependency.
# NOTE(review): presumably mhn.auth.models imports `db` from this module, so
# this import must stay below the `db = SQLAlchemy()` line.
from mhn.auth.models import User, Role, ApiKey
user_datastore = SQLAlchemyUserDatastore(db, User, Role)


# The application object; settings come from the module named 'config'.
mhn = Flask(__name__)
mhn.config.from_object('config')
csrf.init_app(mhn)

# Email app setup.
mail = Mail()
mail.init_app(mhn)

# Registering app on db instance.
db.init_app(mhn)

# Setup flask-security for auth.
# The Security instance is not kept in a name; only its side effect of
# registering with `mhn` is used here.
Security(mhn, user_datastore)

# Registering blueprints.
# Imported late, presumably because the views import `mhn`/`db` from this
# module (circular import) -- confirm before moving it to the top.
from mhn.api.views import api
mhn.register_blueprint(api)
Beispiel #5
0
# Import flask and template operators
from flask import Flask, render_template

# Import Peewee
from config import DATABASE
from peewee import SqliteDatabase

from flask_login import LoginManager

# Define the WSGI application object
letstalk = Flask(__name__)

# CSRF
# NOTE(review): CSRF is initialised before the config below is loaded; values
# set in 'config' still override the extension's defaults, but confirm the
# ordering is intended.
from flask_wtf.csrf import CSRFProtect
csrf = CSRFProtect()
csrf.init_app(letstalk)

# Configurations
letstalk.config.from_object('config')

# Define the database object which is imported
# by modules and controllers
db = SqliteDatabase(DATABASE)

# Login manager bound to the application.
login_m = LoginManager()
login_m.init_app(letstalk)

# Imported after `letstalk`/`db` exist, presumably because the controllers
# import them from this module (circular import).
from app.controllers import views

# Star import provides (at least) the User and Post models used below.
from app.models.tables import *
# Runs at import time: creates the User and Post tables in the SQLite DB.
db.create_tables([User, Post])
Beispiel #6
0
@app.route('/healthcheck')
def healthcheck():
    """Liveness probe: always answer with the plain body 'ok'."""
    status_body = 'ok'
    return status_body


### Flask Mail ###
# Mail extension bound directly to the already-created `app`.
from flask_mail import Mail
mail = Mail(app=app)
# Imported after `mail` exists -- presumably the utils module imports it from
# here (circular import); the alias is presumably used elsewhere in the file.
from security_monkey.common.utils import send_email as common_send_email


### Flask-WTF CSRF Protection ###
# CSRFError is also needed by the error handler registered below.
from flask_wtf.csrf import CSRFProtect, CSRFError

csrf = CSRFProtect()
csrf.init_app(app)


@app.errorhandler(CSRFError)
def csrf_error(reason):
    """Answer a CSRF validation failure with the JSON error template and 400.

    Args:
        reason: The ``CSRFError`` raised by Flask-WTF; it is logged and handed
            to the template unchanged.
    """
    log_line = "CSRF ERROR: {}".format(reason)
    app.logger.debug(log_line)
    rendered = render_template('csrf_error.json', reason=reason)
    return rendered, 400


# Models imported this late, presumably because security_monkey.datastore
# depends on objects created earlier in this module -- confirm before moving.
from security_monkey.datastore import User, Role

### Flask-Security ###
from flask_security.core import Security
from flask_security.datastore import SQLAlchemyUserDatastore
# Flask-Security user/role storage backed by the SQLAlchemy `db`.
user_datastore = SQLAlchemyUserDatastore(db, User, Role)
security = Security(app, user_datastore)
Beispiel #7
0
from flask import Flask, send_from_directory, redirect
from .views import (auth, admin, user)
from .models import db
from flask_wtf.csrf import CSRFProtect, generate_csrf

# Application object: config.py is resolved relative to the instance folder
# and templates are loaded from ./templates.
instance = Flask(
    __name__,
    instance_relative_config=True,
    template_folder='templates',
)

# CSRF protection is attached straight away, before config.py is read.
csrf = CSRFProtect()
csrf.init_app(instance)

instance.config.from_pyfile('config.py')

# View blueprints, registered in the original order.
instance.register_blueprint(auth.bp)
instance.register_blueprint(admin.bp)
instance.register_blueprint(user.bp)


@instance.before_request
def before_request():
    """(Re)initialise the peewee database from config and open a connection."""
    settings = instance.config
    db.init(
        settings.get('DB_NAME'),
        host=settings.get('DB_HOST'),
        user=settings.get('DB_USER'),
        password=settings.get('DB_PASS'),
    )
    db.connect(reuse_if_open=True)


@instance.after_request
def after_request(response):
    """Close the request's database connection and pass the response on.

    Flask requires ``after_request`` handlers to return a response object;
    the previous version implicitly returned ``None``, breaking every request.

    NOTE(review): after_request handlers are skipped when a view raises; a
    ``teardown_request`` handler would also close the connection in that case.
    """
    if not db.is_closed():
        db.close()
    return response
Beispiel #8
0
    def run(self):
        app = Flask(__name__)
        app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
        csrf = CSRFProtect()
        import os
        SECRET_KEY = os.urandom(32)
        app.config['SECRET_KEY'] = SECRET_KEY
        csrf.init_app(app)

        @app.route('/')
        def index():
            not_configured = self.check_configured()
            if not_configured:
                return not_configured
            events = self.control.get_events()
            message = request.args.get('message', None)
            message_type = request.args.get('message_type', None)
            return render_template('events.html',
                                   client=self.name,
                                   state=ReducerStateToString(
                                       self.control.state()),
                                   events=events,
                                   logs=None,
                                   refresh=True,
                                   configured=True,
                                   message=message,
                                   message_type=message_type)

        # http://localhost:8090/add?name=combiner&address=combiner&port=12080&token=e9a3cb4c5eaff546eec33ff68a7fbe232b68a192
        @app.route('/status')
        def status():
            return {'state': ReducerStateToString(self.control.state())}

        @app.route('/netgraph')
        def netgraph():

            result = {'nodes': [], 'edges': []}

            result['nodes'].append({
                "id": "r0",
                "label": "Reducer",
                "x": -1.2,
                "y": 0,
                "size": 25,
                "type": 'reducer',
            })
            x = 0
            y = 0
            count = 0
            meta = {}
            combiner_info = []
            for combiner in self.control.network.get_combiners():
                try:
                    report = combiner.report()
                    combiner_info.append(report)
                except:
                    pass
            y = y + 0.5
            width = 5
            if len(combiner_info) < 1:
                return result
            step = 5 / len(combiner_info)
            x = -width / 3.0
            for combiner in combiner_info:
                print("combiner info {}".format(combiner_info), flush=True)

                try:
                    result['nodes'].append({
                        "id":
                        combiner['name'],  # "n{}".format(count),
                        "label":
                        "Combiner ({} clients)".format(
                            combiner['nr_active_clients']),
                        "x":
                        x,
                        "y":
                        y,
                        "size":
                        15,
                        "name":
                        combiner['name'],
                        "type":
                        'combiner',
                        # "color":'blue',
                    })
                except Exception as err:
                    print(err)

                x = x + step
                count = count + 1
            y = y + 0.25

            count = 0
            width = 5
            step = 5 / len(combiner_info)
            x = -width / 2.0
            # for combiner in self.control.statestore.list_clients():
            for combiner in combiner_info:
                for a in range(0, int(combiner['nr_active_clients'])):
                    # y = y + 0.25
                    try:
                        result['nodes'].append({
                            "id": "c{}".format(count),
                            "label": "Client",
                            "x": x,
                            "y": y,
                            "size": 15,
                            "name": "c{}".format(count),
                            "combiner": combiner['name'],
                            "type": 'client',
                            # "color":'blue',
                        })
                    except Exception as err:
                        print(err)
                    # print("combiner prefferred name {}".format(client['combiner']), flush=True)
                    x = x + 0.25
                    count = count + 1

            count = 0
            for node in result['nodes']:
                try:
                    if node['type'] == 'combiner':
                        result['edges'].append({
                            "id": "e{}".format(count),
                            "source": node['id'],
                            "target": 'r0',
                        })
                    elif node['type'] == 'client':
                        result['edges'].append({
                            "id": "e{}".format(count),
                            "source": node['combiner'],
                            "target": node['id'],
                        })
                except Exception as e:
                    pass
                count = count + 1

            return result

        @app.route('/events')
        def events():
            import json
            from bson import json_util

            json_docs = []
            for doc in self.control.get_events():
                json_doc = json.dumps(doc, default=json_util.default)
                json_docs.append(json_doc)

            json_docs.reverse()
            return {'events': json_docs}

        @app.route('/add')
        def add():
            """ Add a combiner to the network. """
            if self.control.state() == ReducerState.setup:
                return jsonify({'status': 'retry'})

            # TODO check for get variables
            name = request.args.get('name', None)
            address = str(request.args.get('address', None))
            port = request.args.get('port', None)
            # token = request.args.get('token')
            # TODO do validation

            if port is None or address is None or name is None:
                return "Please specify correct parameters."

            # Try to retrieve combiner from db
            combiner = self.control.network.get_combiner(name)
            if not combiner:
                # Create a new combiner
                import base64
                certificate, key = self.certificate_manager.get_or_create(
                    address).get_keypair_raw()
                cert_b64 = base64.b64encode(certificate)
                key_b64 = base64.b64encode(key)

                # TODO append and redirect to index.
                import copy
                combiner = CombinerInterface(self, name, address, port,
                                             copy.deepcopy(certificate),
                                             copy.deepcopy(key),
                                             request.remote_addr)
                self.control.network.add_combiner(combiner)

            combiner = self.control.network.get_combiner(name)

            ret = {
                'status': 'added',
                'certificate': combiner['certificate'],
                'key': combiner['key'],
                'storage': self.control.statestore.get_storage_backend(),
                'statestore': self.control.statestore.get_config(),
            }

            return jsonify(ret)

        @app.route('/eula', methods=['GET', 'POST'])
        def eula():
            for r in request.headers:
                print("header contains: {}".format(r), flush=True)

            return render_template('eula.html', configured=True)

        @app.route('/models', methods=['GET', 'POST'])
        def models():

            if request.method == 'POST':
                # upload seed file
                uploaded_seed = request.files['seed']
                if uploaded_seed:
                    from io import BytesIO
                    a = BytesIO()
                    a.seek(0, 0)
                    uploaded_seed.seek(0)
                    a.write(uploaded_seed.read())
                    helper = self.control.get_helper()
                    model = helper.load_model_from_BytesIO(a.getbuffer())
                    self.control.commit(uploaded_seed.filename, model)
            else:
                not_configured = self.check_configured()
                if not_configured:
                    return not_configured
                h_latest_model_id = self.control.get_latest_model()

                model_info = self.control.get_model_info()
                return render_template('models.html',
                                       h_latest_model_id=h_latest_model_id,
                                       seed=True,
                                       model_info=model_info,
                                       configured=True)

            seed = True
            return redirect(url_for('models', seed=seed))

        @app.route('/delete_model_trail', methods=['GET', 'POST'])
        def delete_model_trail():
            if request.method == 'POST':
                from fedn.common.tracer.mongotracer import MongoTracer
                statestore_config = self.control.statestore.get_config()
                self.tracer = MongoTracer(statestore_config['mongo_config'],
                                          statestore_config['network_id'])
                try:
                    self.control.drop_models()
                except:
                    pass

                # drop objects in minio
                self.control.delete_bucket_objects()
                return redirect(url_for('models'))
            seed = True
            return redirect(url_for('models', seed=seed))

        @app.route('/drop_control', methods=['GET', 'POST'])
        def drop_control():
            if request.method == 'POST':
                self.control.statestore.drop_control()
                return redirect(url_for('control'))
            return redirect(url_for('control'))

        # http://localhost:8090/control?rounds=4&model_id=879fa112-c861-4cb1-a25d-775153e5b548
        @app.route('/control', methods=['GET', 'POST'])
        def control():
            not_configured = self.check_configured()
            if not_configured:
                return not_configured
            client = self.name
            state = ReducerStateToString(self.control.state())
            logs = None
            refresh = True
            try:
                self.current_compute_context = self.control.get_compute_context(
                )
            except:
                self.current_compute_context = None

            if self.current_compute_context == None or self.current_compute_context == '':
                return render_template(
                    'setup.html',
                    client=client,
                    state=state,
                    logs=logs,
                    refresh=False,
                    message=
                    'No compute context is set. Please set one here <a href="/context">/context</a>'
                )

            if self.control.state() == ReducerState.setup:
                return render_template(
                    'setup.html',
                    client=client,
                    state=state,
                    logs=logs,
                    refresh=refresh,
                    message=
                    'Warning. Reducer is not base-configured. please do so with config file.'
                )

            if self.control.state() == ReducerState.monitoring:
                return redirect(
                    url_for('index',
                            state=state,
                            refresh=refresh,
                            message="Reducer is in monitoring state"))

            if request.method == 'POST':
                timeout = float(request.form.get('timeout', 180))
                rounds = int(request.form.get('rounds', 1))
                task = (request.form.get('task', ''))
                clients_required = request.form.get('clients_required', 1)
                clients_requested = request.form.get('clients_requested', 8)

                # checking if there are enough clients connected to start!
                clients_available = 0
                try:
                    for combiner in self.control.network.get_combiners():
                        if combiner.allowing_clients():
                            combiner_state = combiner.report()
                            nac = combiner_state['nr_active_clients']

                            clients_available = clients_available + int(nac)
                except Exception as e:
                    pass

                if clients_available < clients_required:
                    return redirect(
                        url_for(
                            'index',
                            state=state,
                            message=
                            "Not enough clients available to start rounds.",
                            message_type='warning'))

                validate = request.form.get('validate', False)
                if validate == 'False':
                    validate = False
                helper_type = request.form.get('helper', 'keras')
                # self.control.statestore.set_framework(helper_type)

                latest_model_id = self.control.get_latest_model()

                config = {
                    'round_timeout': timeout,
                    'model_id': latest_model_id,
                    'rounds': rounds,
                    'clients_required': clients_required,
                    'clients_requested': clients_requested,
                    'task': task,
                    'validate': validate,
                    'helper_type': helper_type
                }

                import threading
                threading.Thread(target=self.control.instruct,
                                 args=(config, )).start()
                # self.control.instruct(config)
                return redirect(
                    url_for('index',
                            state=state,
                            refresh=refresh,
                            message="Sent execution plan.",
                            message_type='SUCCESS'))

            else:
                seed_model_id = None
                latest_model_id = None
                try:
                    seed_model_id = self.control.get_first_model()[0]
                    latest_model_id = self.control.get_latest_model()
                except Exception as e:
                    pass

                return render_template(
                    'index.html',
                    latest_model_id=latest_model_id,
                    compute_package=self.current_compute_context,
                    seed_model_id=seed_model_id,
                    helper=self.control.statestore.get_framework(),
                    validate=True,
                    configured=True)

            client = self.name
            state = ReducerStateToString(self.control.state())
            logs = None
            refresh = False
            return render_template('index.html',
                                   client=client,
                                   state=state,
                                   logs=logs,
                                   refresh=refresh,
                                   configured=True)

        @app.route('/assign')
        def assign():
            """Handle client assignment requests. """

            if self.control.state() == ReducerState.setup:
                return jsonify({'status': 'retry'})

            name = request.args.get('name', None)
            combiner_preferred = request.args.get('combiner', None)

            if combiner_preferred:
                combiner = self.control.find(combiner_preferred)
            else:
                combiner = self.control.find_available_combiner()

            if combiner is None:
                return jsonify({'status': 'retry'})
            ## Check that a framework has been selected prior to assigning clients.
            framework = self.control.statestore.get_framework()
            if not framework:
                return jsonify({'status': 'retry'})

            client = {
                'name': name,
                'combiner_preferred': combiner_preferred,
                'combiner': combiner.name,
                'ip': request.remote_addr,
                'status': 'available'
            }
            self.control.network.add_client(client)

            if combiner:
                import base64
                cert_b64 = base64.b64encode(combiner.certificate)
                response = {
                    'status': 'assigned',
                    'host': combiner.address,
                    'ip': combiner.ip,
                    'port': combiner.port,
                    'certificate': str(cert_b64).split('\'')[1],
                    'model_type': self.control.statestore.get_framework()
                }

                return jsonify(response)
            elif combiner is None:
                return jsonify({'status': 'retry'})

            return jsonify({'status': 'retry'})

        @app.route('/infer')
        def infer():
            if self.control.state() == ReducerState.setup:
                return "Error, not configured"
            result = ""
            try:
                self.control.set_model_id()
            except fedn.exceptions.ModelError:
                print("Failed to seed control.")

            return result

        def combiner_stats():
            combiner_info = []
            for combiner in self.control.network.get_combiners():
                try:
                    report = combiner.report()
                    combiner_info.append(report)
                except:
                    pass
                return combiner_info
            return False

        def create_map():
            cities_dict = {
                'city': [],
                'lat': [],
                'lon': [],
                'country': [],
                'name': [],
                'role': [],
                'size': []
            }

            from fedn import get_data
            dbpath = get_data('geolite2/GeoLite2-City.mmdb')

            with geoip2.database.Reader(dbpath) as reader:
                for combiner in self.control.statestore.list_combiners():
                    try:
                        response = reader.city(combiner['ip'])
                        cities_dict['city'].append(response.city.name)

                        r = 1.0  # Rougly 100km
                        w = r * math.sqrt(numpy.random.random())
                        t = 2.0 * math.pi * numpy.random.random()
                        x = w * math.cos(t)
                        y = w * math.sin(t)
                        lat = str(float(response.location.latitude) + x)
                        lon = str(float(response.location.longitude) + y)
                        cities_dict['lat'].append(lat)
                        cities_dict['lon'].append(lon)

                        cities_dict['country'].append(
                            response.country.iso_code)

                        cities_dict['name'].append(combiner['name'])
                        cities_dict['role'].append('Combiner')
                        cities_dict['size'].append(10)

                    except geoip2.errors.AddressNotFoundError as err:
                        print(err)

            with geoip2.database.Reader(dbpath) as reader:
                for client in self.control.statestore.list_clients():
                    try:
                        response = reader.city(client['ip'])
                        cities_dict['city'].append(response.city.name)
                        cities_dict['lat'].append(response.location.latitude)
                        cities_dict['lon'].append(response.location.longitude)
                        cities_dict['country'].append(
                            response.country.iso_code)

                        cities_dict['name'].append(client['name'])
                        cities_dict['role'].append('Client')
                        # TODO: Optionally relate to data size
                        cities_dict['size'].append(6)

                    except geoip2.errors.AddressNotFoundError as err:
                        print(err)

            config = self.control.statestore.get_config()

            cities_df = pd.DataFrame(cities_dict)
            if cities_df.empty:
                return False
            fig = px.scatter_geo(cities_df,
                                 lon="lon",
                                 lat="lat",
                                 projection="natural earth",
                                 color="role",
                                 size="size",
                                 hover_name="city",
                                 hover_data={
                                     "city": False,
                                     "lon": False,
                                     "lat": False,
                                     'size': False,
                                     'name': True,
                                     'role': True
                                 })

            fig.update_geos(fitbounds="locations", showcountries=True)
            fig.update_layout(
                title="FEDn network: {}".format(config['network_id']))

            fig = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
            return fig

        @app.route('/metric_type', methods=['GET', 'POST'])
        def change_features():
            feature = request.args['selected']
            plot = Plot(self.control.statestore)
            graphJSON = plot.create_box_plot(feature)
            return graphJSON

        @app.route('/dashboard')
        def dashboard():
            not_configured = self.check_configured()
            if not_configured:
                return not_configured

            plot = Plot(self.control.statestore)
            try:
                valid_metrics = plot.fetch_valid_metrics()
                box_plot = plot.create_box_plot(valid_metrics[0])
            except Exception as e:
                valid_metrics = None
                box_plot = None
                print(e, flush=True)
            table_plot = plot.create_table_plot()
            # timeline_plot = plot.create_timeline_plot()
            timeline_plot = None
            clients_plot = plot.create_client_plot()
            return render_template('dashboard.html',
                                   show_plot=True,
                                   box_plot=box_plot,
                                   table_plot=table_plot,
                                   timeline_plot=timeline_plot,
                                   clients_plot=clients_plot,
                                   metrics=valid_metrics,
                                   configured=True)

        @app.route('/network')
        def network():
            not_configured = self.check_configured()
            if not_configured:
                return not_configured
            plot = Plot(self.control.statestore)
            round_time_plot = plot.create_round_plot()
            mem_cpu_plot = plot.create_cpu_plot()
            combiners_plot = plot.create_combiner_plot()
            map_plot = create_map()
            combiner_info = combiner_stats()
            return render_template('network.html',
                                   map_plot=map_plot,
                                   network_plot=True,
                                   round_time_plot=round_time_plot,
                                   mem_cpu_plot=mem_cpu_plot,
                                   combiners_plot=combiners_plot,
                                   combiner_info=combiner_info,
                                   configured=True)

        @app.route('/config/download', methods=['GET'])
        def config_download():

            chk_string = ""
            name = self.control.get_compute_context()
            if name is None or name == '':
                chk_string = ''
            else:
                file_path = os.path.join(UPLOAD_FOLDER, name)
                print("trying to get {}".format(file_path))
                from fedn.utils.checksum import md5

                try:
                    sum = str(md5(file_path))
                except FileNotFoundError as e:
                    sum = ''
                chk_string = "checksum: {}".format(sum)

            network_id = self.network_id
            discover_host = self.name
            discover_port = self.port
            token = self.token
            ctx = """network_id: {network_id}
controller:
    discover_host: {discover_host}
    discover_port: {discover_port}
    token: {token}
    {chk_string}""".format(network_id=network_id,
                           discover_host=discover_host,
                           discover_port=discover_port,
                           token=token,
                           chk_string=chk_string)

            from io import BytesIO
            from flask import send_file
            obj = BytesIO()
            obj.write(ctx.encode('UTF-8'))
            obj.seek(0)
            return send_file(obj,
                             as_attachment=True,
                             attachment_filename='client.yaml',
                             mimetype='application/x-yaml')

        @app.route('/context', methods=['GET', 'POST'])
        @csrf.exempt  # TODO fix csrf token to form posting in package.py
        def context():
            # if self.control.state() != ReducerState.setup or self.control.state() != ReducerState.idle:
            #    return "Error, Context already assigned!"
            reset = request.args.get(
                'reset',
                None)  # if reset is not empty then allow context re-set
            if reset:
                return render_template('context.html')

            if request.method == 'POST':

                if 'file' not in request.files:
                    flash('No file part')
                    return redirect(url_for('context'))

                file = request.files['file']
                helper_type = request.form.get('helper', 'keras')
                # if user does not select file, browser also
                # submit an empty part without filename
                if file.filename == '':
                    flash('No selected file')
                    return redirect(url_for('context'))

                if file and allowed_file(file.filename):
                    filename = secure_filename(file.filename)
                    file_path = os.path.join(app.config['UPLOAD_FOLDER'],
                                             filename)
                    file.save(file_path)

                    if self.control.state(
                    ) == ReducerState.instructing or self.control.state(
                    ) == ReducerState.monitoring:
                        return "Not allowed to change context while execution is ongoing."

                    self.control.set_compute_context(filename, file_path)
                    self.control.statestore.set_framework(helper_type)
                    return redirect(url_for('control'))

            from flask import send_from_directory
            name = request.args.get('name', '')

            if name == '':
                name = self.control.get_compute_context()
                if name == None or name == '':
                    return render_template('context.html')

            # There is a potential race condition here, if one client requests a package and at
            # the same time another one triggers a fetch from Minio and writes to disk.
            try:
                mutex = Lock()
                mutex.acquire()
                return send_from_directory(app.config['UPLOAD_FOLDER'],
                                           name,
                                           as_attachment=True)
            except:
                try:
                    data = self.control.get_compute_package(name)
                    file_path = os.path.join(app.config['UPLOAD_FOLDER'], name)
                    with open(file_path, 'wb') as fh:
                        fh.write(data)
                    return send_from_directory(app.config['UPLOAD_FOLDER'],
                                               name,
                                               as_attachment=True)
                except:
                    raise
            finally:
                mutex.release()

            return render_template('context.html')

        @app.route('/checksum', methods=['GET', 'POST'])
        def checksum():

            #sum = ''
            name = request.args.get('name', None)
            if name == '' or name is None:
                name = self.control.get_compute_context()
                if name == None or name == '':
                    return jsonify({})

            file_path = os.path.join(UPLOAD_FOLDER, name)
            print("trying to get {}".format(file_path))
            from fedn.utils.checksum import md5

            try:
                sum = str(md5(file_path))
            except FileNotFoundError as e:
                sum = ''

            data = {'checksum': sum}
            from flask import jsonify
            return jsonify(data)

        if self.certificate:
            print("trying to connect with certs {} and key {}".format(
                str(self.certificate.cert_path),
                str(self.certificate.key_path)),
                  flush=True)
            app.run(host="0.0.0.0",
                    port=self.port,
                    ssl_context=(str(self.certificate.cert_path),
                                 str(self.certificate.key_path)))
Beispiel #9
0
def create_app(config_path=None):
    """
    Creates a Flask app instance and returns it

    Args:
        config_path (str): A path to a JSON object with an app configuration.

    Note:
        Some default configurations can be found in `config`.
        The database URI depends on the config: ``sqlite:////tmp/ovs.db`` when
        ``SELENIUM`` is set, in-memory SQLite when ``ENV`` is ``'TEST'``, and
        otherwise MySQL (PyMySQL driver) built from the ``DATABASE`` section.

    Returns:
        Flask: a Flask app.
    """
    from urllib.parse import quote  # escape credentials embedded in the DB URL

    app = Flask(__name__)
    config = OVSConfig(config_path)
    for key, value in config.items():
        app.config[key] = value

    env = app.config['ENV']
    app.config['DEVELOPMENT'] = env == 'DEV'
    app.config['TESTING'] = env == 'TEST'
    app.config['PRODUCTION'] = env == 'PROD'

    with app.app_context():
        db_config = app.config['DATABASE']
        if app.config['SELENIUM']:
            app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:////tmp/ovs.db'
        elif app.config['TESTING']:
            app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite://'
        else:
            # Percent-escape user and password: an '@', ':' or '/' inside them
            # would otherwise be parsed as URL syntax and break the connection.
            # safe='' also escapes '/'; quote() encodes spaces as %20, which
            # SQLAlchemy unquotes back to a space. The f-string also accepts a
            # numeric PORT, which plain '+' concatenation rejected.
            user = quote(str(db_config['USER']), safe='')
            password = quote(str(db_config['PASSWORD']), safe='')
            app.config['SQLALCHEMY_DATABASE_URI'] = (
                f"mysql+pymysql://{user}:{password}@"
                f"{db_config['HOSTNAME']}:{db_config['PORT']}/{db_config['NAME']}"
            )
        app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

        db.init_app(app)
        if app.config['SELENIUM']:
            from ovs.datagen import DataGen  # Avoid circular dependencies.
            DataGen.clear_db()
        else:
            import ovs.models  # pylint: disable=unused-variable
            db.create_all()

        from ovs.blob import blob
        blob.init_app(app)

        from ovs.models.user_model import bcrypt_app
        bcrypt_app.init_app(app)

        from ovs.utils import serializer
        serializer.init_app(app)

        from ovs.services.auth_service import LOGIN_MANAGER
        LOGIN_MANAGER.init_app(app)
        LOGIN_MANAGER.login_view = '/'
        LOGIN_MANAGER.login_message_category = 'danger'

        from flask_wtf.csrf import CSRFProtect
        csrf = CSRFProtect()
        csrf.init_app(app)

        from ovs import routes
        app.register_blueprint(routes.OvsRoutes, url_prefix='/')
        app.register_blueprint(routes.AdminRoutes, url_prefix='/admin')
        app.register_blueprint(routes.ManagerRoutes, url_prefix='/manager')
        app.register_blueprint(routes.ResidentRoutes, url_prefix='/resident')
        app.register_blueprint(routes.AuthRoutes, url_prefix='/auth')

        app.jinja_env.globals.update(utc_to_timezone=utc_to_timezone)

        # Seed default data in the reloader's child process (or when not
        # running under FLASK_DEBUG), and never in tests.
        if (os.environ.get("WERKZEUG_RUN_MAIN") == "true" or os.environ.get("FLASK_DEBUG") != "True")\
           and not app.config['TESTING']:
            from ovs.datagen import DataGen  # Avoid circular dependencies.
            DataGen.create_defaults()

        db.session.commit()

    return app
Beispiel #10
0
         MAIL_USE_SSL=True,
         MAIL_USERNAME="******",
         MAIL_PASSWORD="******",
         MAIL_DEFAULT_SENDER='default mail sender',
         SQLALCHEMY_COMMIT_ON_TEARDOWN=True))

# Flask-Login: session management for the ``server`` app.
login_manager = LoginManager()
login_manager.anonymous_user = Anonymous
login_manager.session_protection = 'strong'
login_manager.login_view = 'index'
login_manager.init_app(app=server)

# Flask-WTF: CSRF protection for every form / POST endpoint.
csrf = CSRFProtect()
csrf.init_app(server)

# Flask-Dropzone upload widget and its settings.
dropzone = Dropzone()
dropzone.init_app(server)
server.config.update({
    'DROPZONE_ENABLE_CSRF': True,
    'DROPZONE_ALLOWED_FILE_CUSTOM': True,
    'DROPZONE_ALLOWED_FILE_TYPE': '.png, .jpg, .jpeg, .JPG, .pdf, .JPEG',
    'DROPZONE_UPLOAD_MULTIPLE': False,
    'DROPZONE_PARALLEL_UPLOADS': 1,
    'DROPZONE_MAX_FILES': 10,
    'DROPZONE_MAX_FILE_SIZE': 10,
})

@server.teardown_request
Beispiel #11
0
class Server(Flask):
    """eNMS web application.

    ``__init__`` applies the configuration and registers the CSRF / Socket.IO
    extensions, the login manager, the template context processor, the error
    handlers, the blueprint routes and the ``/terminal`` Socket.IO namespace.
    """

    # Log level used by ``process_requests`` for each response status code.
    status_log_level = {
        200: "info",
        401: "warning",
        403: "warning",
        404: "info",
        500: "error",
    }

    # Text of the JSON ``alert`` returned for failed non-browser requests.
    status_error_message = {
        401: "Wrong Credentials",
        403: "Operation not allowed.",
        404: "Invalid POST request.",
        500: "Internal Server Error.",
    }

    def __init__(self):
        static_folder = str(vs.path / "eNMS" / "static")
        super().__init__(__name__, static_folder=static_folder)
        self.update_config()
        self.register_extensions()
        self.configure_login_manager()
        self.configure_context_processor()
        self.configure_errors()
        self.configure_routes()
        self.configure_terminal_socket()

    def update_config(self):
        """Apply Flask / Flask-WTF settings derived from ``vs.settings``."""
        session_timeout = vs.settings["app"]["session_timeout_minutes"]
        self.config.update({
            "DEBUG":
            vs.settings["app"]["config_mode"].lower() != "production",
            # NOTE(review): falls back to a hard-coded, well-known key when the
            # SECRET_KEY environment variable is unset, which makes session
            # cookies forgeable — consider refusing to start instead.
            "SECRET_KEY":
            getenv("SECRET_KEY", "secret_key"),
            "WTF_CSRF_TIME_LIMIT":
            None,
            "ERROR_404_HELP":
            False,
            "MAX_CONTENT_LENGTH":
            20 * 1024 * 1024,
            # CSRF is switched off when a "pytest" module is loaded
            # (``modules`` is presumably ``sys.modules``).
            "WTF_CSRF_ENABLED":
            "pytest" not in modules,
            "PERMANENT_SESSION_LIFETIME":
            timedelta(minutes=session_timeout),
        })

    def register_plugins(self):
        """Instantiate ``eNMS.plugins.<name>.Plugin`` for each configured plugin.

        A plugin that fails to import or initialise is logged and skipped.
        NOTE(review): not called from ``__init__``; confirm it is invoked elsewhere.
        """
        for plugin, settings in vs.plugins_settings.items():
            try:
                module = import_module(f"eNMS.plugins.{plugin}")
                module.Plugin(self, controller, db, vs, env, **settings)
            except Exception:
                env.log(
                    "error",
                    f"Could not import plugin '{plugin}':\n{format_exc()}")
                continue
            info(f"Loading plugin: {settings['name']}")

    def register_extensions(self):
        """Attach CSRF protection and Socket.IO to the app."""
        self.csrf = CSRFProtect()
        self.csrf.init_app(self)
        self.socketio = SocketIO(self)

    def configure_login_manager(self):
        """Set up Flask-Login; users are loaded by name through ``db.get_user``."""
        login_manager = LoginManager()
        login_manager.session_protection = "strong"
        login_manager.init_app(self)

        @login_manager.user_loader
        def user_loader(name):
            return db.get_user(name)

    def configure_terminal_socket(self):
        """Register the ``/terminal`` Socket.IO handlers for web SSH/telnet."""
        from codecs import getincrementaldecoder
        import os

        def send_data(session, file_descriptor):
            # Relay terminal output to the session's room and append it to the
            # stored session record until the descriptor is closed.
            session_object = db.factory(
                "session",
                commit=True,
                name=session,
                timestamp=str(datetime.now()),
                **vs.ssh_sessions[session],
            )
            # A fixed-size read can stop in the middle of a multi-byte UTF-8
            # character; a strict one-shot ``.decode()`` then raised
            # UnicodeDecodeError (not caught below) and killed this task.
            decoder = getincrementaldecoder("utf-8")(errors="replace")
            while True:
                try:
                    self.socketio.sleep(0.1)
                    chunk = read(file_descriptor, 1024)
                    if not chunk:
                        break  # EOF (some platforms return b"" instead of EIO)
                    output = decoder.decode(chunk)
                    session_object.content += output
                    self.socketio.emit("output",
                                       output,
                                       namespace="/terminal",
                                       room=session)
                    db.session.commit()
                except OSError:
                    break

        @self.socketio.on("input", namespace="/terminal")
        def input(data):
            session = vs.ssh_sessions[request.args["session"]]
            write(session["file_descriptor"], data.encode())

        @self.socketio.on("join", namespace="/terminal")
        def on_join(session):
            join_room(session)

        @self.socketio.on("connect", namespace="/terminal")
        def connect():
            session_id = request.args["session"]
            session = vs.ssh_sessions.get(session_id)
            if not session:
                return
            device = db.fetch("device", id=session["device"])
            username, password = session["credentials"]
            address, options = getattr(device, session["form"]["address"]), ""
            if vs.settings["ssh"]["bypass_key_prompt"]:
                options = "-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
            process_id, session["file_descriptor"] = fork()
            if process_id:
                task = partial(send_data, session_id,
                               session["file_descriptor"])
                self.socketio.start_background_task(target=task)
            else:
                # Forked child: build argv as a list so values containing
                # whitespace are passed intact (``str.split`` broke them apart).
                port = ["-p", str(device.port)]
                if session["form"]["protocol"] == "telnet":
                    command = ["telnet", str(address)]
                elif password:
                    # ``sshpass -e`` reads the password from $SSHPASS, keeping it
                    # out of argv where any local user could see it via ``ps``.
                    os.environ["SSHPASS"] = str(password)
                    command = [
                        "sshpass", "-e", "ssh", *options.split(),
                        f"{username}@{address}", *port
                    ]
                else:
                    command = ["ssh", *options.split(), str(address), *port]
                try:
                    run(command)
                except Exception:
                    # Best effort: surface launch failures on the child's stdout.
                    print(format_exc(), flush=True)
                finally:
                    # Never return: the child is a copy of the web server process
                    # and must not carry on serving once the command finishes.
                    os._exit(0)

    def configure_context_processor(self):
        """Inject the serialized user, server time and ``vs.template_context``."""
        @self.context_processor
        def inject_properties():
            user = current_user.serialized if current_user.is_authenticated else None
            return {
                "user": user,
                "time": str(vs.get_time()),
                **vs.template_context
            }

    def configure_errors(self):
        """Render ``error.html`` for 403 and 404 errors."""
        @self.errorhandler(403)
        def authorization_required(error):
            login_url = url_for("blueprint.route", page="login")
            return render_template("error.html",
                                   error=403,
                                   login_url=login_url), 403

        @self.errorhandler(404)
        def not_found_error(error):
            return render_template("error.html", error=404), 404

    @staticmethod
    def process_requests(function):
        """Wrap a view with RBAC lookup, request logging and error responses.

        The endpoint is the first path segment (first two for ``/rest/``
        paths); if it has no entry in ``vs.rbac["<method>_requests"]`` the view
        is not called and the request is answered as 404. Exceptions raised by
        the view map to 403 / 404 / 500. Failed requests to ``/login`` or
        browser GETs render ``error.html``; all others get a JSON ``alert``.
        """
        @wraps(function)
        def decorated_function(*args, **kwargs):
            remote_address = request.environ["REMOTE_ADDR"]
            client_address = request.environ.get("HTTP_X_FORWARDED_FOR",
                                                 remote_address)
            rest_request = request.path.startswith("/rest/")
            endpoint = "/".join(request.path.split("/")[:2 + rest_request])
            request_property = f"{request.method.lower()}_requests"
            endpoint_rbac = vs.rbac[request_property].get(endpoint)
            # NOTE(review): every unauthenticated request is silently logged in
            # as "admin", bypassing authentication entirely — confirm this is
            # intended (e.g. a demo build) before deploying.
            if not current_user.is_authenticated:
                login_user(db.get_user("admin"))
            username = getattr(current_user, "name", "Unknown")
            if not endpoint_rbac:
                status_code = 404
            else:
                try:
                    result = function(*args, **kwargs)
                    status_code = 200
                except (db.rbac_error, Forbidden):
                    status_code = 403
                except NotFound:
                    status_code = 404
                except Exception:
                    status_code, traceback = 500, format_exc()
            log = (f"USER: {username} ({client_address}) - "
                   f"{request.method} {request.path} ({status_code})")
            if status_code == 500:
                log += f"\n{traceback}"
            env.log(Server.status_log_level[status_code],
                    log,
                    change_log=False)
            if status_code == 200:
                return result
            elif endpoint == "/login" or request.method == "GET" and not rest_request:
                if (not current_user.is_authenticated and not rest_request
                        and endpoint != "/login"):
                    url = url_for("blueprint.route",
                                  page="login",
                                  next_url=request.url)
                    return redirect(login_url(url))
                next_url = request.args.get("next_url")
                login_link = login_url(
                    url_for("blueprint.route", page="login",
                            next_url=next_url))
                return (
                    render_template("error.html",
                                    error=status_code,
                                    login_url=login_link),
                    status_code,
                )
            else:
                error_message = Server.status_error_message[status_code]
                alert = f"Error {status_code} - {error_message}"
                return jsonify({"alert": alert}), status_code

        return decorated_function

    def configure_routes(self):
        """Define all page and POST routes on a single ``blueprint`` blueprint."""
        blueprint = Blueprint("blueprint",
                              __name__,
                              template_folder="../templates")

        @blueprint.route("/")
        @self.process_requests
        def site_root():
            return redirect(url_for("blueprint.route", page="login"))

        @blueprint.route("/login", methods=["GET", "POST"])
        @self.process_requests
        def login():
            return redirect(url_for("blueprint.route", page="dashboard"))

        @blueprint.route("/dashboard")
        @self.process_requests
        def dashboard():
            return render_template(
                "dashboard.html",
                **{
                    "endpoint": "dashboard",
                    "properties": vs.properties["dashboard"]
                },
            )

        @blueprint.route("/logout")
        @self.process_requests
        def logout():
            logout_log = f"USER '{current_user.name}' logged out"
            logout_user()
            env.log("info", logout_log, logger="security")
            return redirect(url_for("blueprint.route", page="login"))

        @blueprint.route("/<table_type>_table")
        @self.process_requests
        def table(table_type):
            return render_template(
                "table.html", **{
                    "endpoint": f"{table_type}_table",
                    "type": table_type
                })

        @blueprint.route("/view_builder")
        @blueprint.route("/logical_view")
        @blueprint.route("/geographical_view")
        @self.process_requests
        def view():
            return render_template("visualization.html",
                                   endpoint=request.path[1:])

        @blueprint.route("/workflow_builder")
        @self.process_requests
        def workflow_builder():
            return render_template("workflow.html",
                                   endpoint="workflow_builder")

        @blueprint.route("/<form_type>_form")
        @self.process_requests
        def form(form_type):
            form = vs.form_class[form_type](request.form)
            return render_template(
                f"forms/{getattr(form, 'template', 'base')}.html",
                **{
                    "endpoint": f"forms/{form_type}",
                    "action": getattr(form, "action", None),
                    "button_label": getattr(form, "button_label", "Confirm"),
                    "button_class": getattr(form, "button_class", "success"),
                    "form": form,
                    "form_type": form_type,
                },
            )

        @blueprint.route("/parameterized_form/<service_id>")
        @self.process_requests
        def parameterized_form(service_id):
            global_variables = {
                "form": None,
                "BaseForm": BaseForm,
                **vs.form_context
            }
            indented_form = "\n".join(" " * 4 + line for line in (
                f"form_type = HiddenField(default='initial-{service_id}')",
                *db.fetch("service",
                          id=service_id).parameterized_form.splitlines(),
            ))
            full_form = f"class Form(BaseForm):\n{indented_form}\nform = Form"
            # NOTE(review): exec() runs Python source stored on the service
            # record; anyone who can edit a service can run code on the server.
            try:
                exec(full_form, global_variables)
            except Exception:
                return (
                    "<div style='margin: 8px'>The parameterized form could not be  "
                    "loaded because of the following error:"
                    f"<br><pre>{format_exc()}</pre></div>")
            return render_template(
                "forms/base.html",
                **{
                    "form_type": f"initial-{service_id}",
                    "action": "eNMS.automation.submitInitialForm",
                    "button_label": "Run Service",
                    "button_class": "primary",
                    "form": global_variables["form"](request.form),
                },
            )

        @blueprint.route("/help/<path:path>")
        @self.process_requests
        def help(path):
            return render_template(f"help/{path}.html")

        @blueprint.route("/view_service_results/<int:run_id>/<int:service>")
        @self.process_requests
        def view_service_results(run_id, service):
            results = db.fetch_all("result", run_id=run_id, service_id=service)
            results_dict = [result.result for result in results]
            if not results_dict:
                return "No Results Found"
            return f"<pre>{vs.dict_to_string(results_dict)}</pre>"

        @blueprint.route("/download_file/<path:path>")
        @self.process_requests
        def download_file(path):
            # NOTE(review): serves any absolute path readable by the server
            # process; access depends solely on the RBAC entry for this endpoint.
            return send_file(f"/{path}", as_attachment=True)

        @blueprint.route("/export_service/<int:id>")
        @self.process_requests
        def export_service(id):
            filename = f"/{controller.export_service(id)}.tgz"
            return send_file(filename, as_attachment=True)

        @blueprint.route("/terminal/<session>")
        @self.process_requests
        def ssh_connection(session):
            return render_template("terminal.html", session=session)

        @blueprint.route("/<path:_>")
        @self.process_requests
        def get_requests_sink(_):
            abort(404)

        @blueprint.route("/", methods=["POST"])
        @blueprint.route("/<path:page>", methods=["POST"])
        @self.process_requests
        def route(page):
            # Dispatch "<controller method>/<positional args...>" with kwargs
            # from the JSON body, a validated WTForm, or the raw form data.
            form_type = request.form.get("form_type")
            endpoint, *args = page.split("/")
            if request.json:
                kwargs = request.json
            elif form_type:
                form = vs.form_class[form_type](request.form)
                if not form.validate_on_submit():
                    return jsonify({
                        "invalid_form": True,
                        "errors": form.errors
                    })
                kwargs = form.form_postprocessing(request.form)
            else:
                kwargs = request.form
            with db.session_scope():
                return jsonify(getattr(controller, endpoint)(*args, **kwargs))

        self.register_blueprint(blueprint)
Beispiel #12
0
def csrf_protection(app):
    """Enable Flask-WTF CSRF protection on *app*.

    The extension instance is registered on the app by ``init_app``; nothing
    is returned.
    """
    CSRFProtect().init_app(app)
Beispiel #13
0
    while True:
        sys.exit()

print('importing modules.....')

from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager
from flask_mail import Mail
from flask_wtf.csrf import CSRFProtect
from ATS import skHandler

import os, sys
csrf = CSRFProtect()
site = Flask(__name__)
csrf.init_app(site)

print("The Python version is {}.{}.{}".format(*sys.version_info[:3]))

# Load the secret key from its .txt file (a new one is generated if missing).
key = skHandler.sk()
key.run()

site.config.update(
    SECRET_KEY=key.key,
    # Database
    SQLALCHEMY_DATABASE_URI='sqlite:///ats.db',
    SQLALCHEMY_TRACK_MODIFICATIONS=False,
    # Outgoing mail via Gmail SMTP
    MAIL_SERVER='smtp.gmail.com',
    MAIL_PORT=587,
    MAIL_USE_TLS=True,
)
Beispiel #14
0
def create_app(config_filename=None, **kwargs):
    """
    Build and configure the sqmpy Flask application (application factory).

    Configuration is layered, each source overriding the previous one:
    ``sqmpy.defaults`` -> ``config`` -> *config_filename* (if given) -> the
    file named by the ``SQMPY_SETTINGS`` environment variable (if set) ->
    *kwargs*.

    :param config_filename: optional extra Python config file; when given it
        must exist (loaded with ``silent=False``).
    :param kwargs: config values applied last, on top of every other source.
    :return: the configured Flask application.
    """
    app = Flask(__name__.split('.')[0], static_url_path='')

    # Configuration sources, lowest to highest precedence.
    app.config.from_object('sqmpy.defaults')
    app.config.from_object('config')
    if config_filename:
        app.config.from_pyfile(config_filename, silent=False)
    app.config.from_envvar('SQMPY_SETTINGS', silent=True)
    app.config.update(kwargs)

    from sqmpy.database import db
    db.init_app(app)

    # CSRF protection is opt-in through the CSRF_ENABLED setting.
    csrf_protect = CSRFProtect()
    if app.config.get('CSRF_ENABLED'):
        csrf_protect.init_app(app)

    from sqmpy.security import security_blueprint
    from sqmpy.job import job_blueprint
    from sqmpy.views import main_blueprint
    for blueprint in (security_blueprint, job_blueprint, main_blueprint):
        app.register_blueprint(blueprint)

    if __debug__:
        # The models are imported together with the blueprints above, so
        # create_all() can see every table by now.
        with app.app_context():
            db.create_all()

    @app.context_processor
    def make_navmenu_items():
        # Offer a sub-menu link to the job index only if that blueprint exists.
        from flask import url_for
        if job_blueprint.name not in app.blueprints:
            return {}
        job_index = url_for(f'{job_blueprint.name}.index')
        return {'navitems': {job_blueprint.name: job_index}}

    @app.before_first_request
    def activate_job_monitor():
        # Start the background job monitor and keep a handle on the app.
        monitor = JobMonitorThread(kwargs={'app': app})
        app.monitor = monitor
        monitor.start()

    return app
Beispiel #15
0
def create_app():
    """Create the Flask server for the public client site and its admin portal.

    Sets up ProxyFix, loads configuration from ``server.<APP_SETTINGS>``,
    initialises the database, CSRF protection, rate limiting and Talisman
    security headers, registers the models and blueprints, and defines the
    top-level routes.

    Returns:
        Flask: the configured application.

    Raises:
        KeyError: if the ``APP_SETTINGS`` environment variable is not set, or
            if one of the relaxed-CSP endpoints below is not registered.
    """
    server = Flask(__name__, template_folder='../client/public')

    server.wsgi_app = ProxyFix(server.wsgi_app, x_for=1)
    server.config.from_object("server." + os.environ["APP_SETTINGS"])

    from .db import db
    from .limiter import limiter
    csrf = CSRFProtect()
    db.init_app(server)
    csrf.init_app(server)
    limiter.init_app(server)
    Talisman(server,
             content_security_policy={
                 'font-src':
                 ["'self'", 'themes.googleusercontent.com', '*.gstatic.com'],
                 'script-src': ["'self'", 'ajax.googleapis.com'],
                 'style-src': [
                     "'self'",
                     'fonts.googleapis.com',
                     '*.gstatic.com',
                     'ajax.googleapis.com',
                     "'unsafe-inline'",
                 ],
                 'default-src': ["'self'", '*.gstatic.com']
             },
             force_https=False)

    from server.models.Client_Resources import Client_Resources
    from server.models.Gallery_Info import Gallery_Info
    from server.models.Image import Image
    from server.models.Paragraph import Paragraph
    from server.models.Galleries import Galleries
    from server.models.Header import Header

    # Order of imports matters here:
    # if placed at the top, these blueprints get no db.
    from server.blueprints.content import content
    from server.blueprints.auth import auth
    from server.blueprints.images import images
    from server.blueprints.blacklist import blacklist
    from server.blueprints.settings import settings
    from server.blueprints.email_service import email_service
    from server.blueprints.gallery import galleries
    from server.blueprints.users import users

    server.register_blueprint(auth, url_prefix="/admin")
    server.register_blueprint(images, url_prefix='/admin/assets')
    server.register_blueprint(content, url_prefix='/content')
    server.register_blueprint(galleries, url_prefix='/galleries')
    server.register_blueprint(blacklist, url_prefix='/admin/blacklist')
    server.register_blueprint(settings)
    server.register_blueprint(email_service)
    server.register_blueprint(users)

    # Custom (relaxed) CSP for the admin portal views, set through the
    # ``talisman_view_options`` attribute of each view function; there is no
    # easier way to do this unfortunately.
    # https://github.com/GoogleCloudPlatform/flask-talisman/issues/45
    relaxed_csp_endpoints = (
        "auth.home",
        "auth.login",
        "images.handle_images",
        "content.render",
        "galleries.handle_images",
        "settings.show_page",
        "users.handle_recovery",
    )
    for endpoint in relaxed_csp_endpoints:
        # Indexing (not .get) fails with a clear KeyError if a view is renamed;
        # a fresh dict per view keeps the options unshared.
        setattr(
            server.view_functions[endpoint],
            "talisman_view_options",
            {"content_security_policy": {
                "default-src": "* 'unsafe-inline'"
            }})

    @server.errorhandler(429)
    def handle_excess_req(e):
        message = "You've requested our site quite rapidly recently. Please try again later."
        error = "Too Many Requests"
        # Without an explicit status Flask would answer 200 OK, hiding the
        # rate limit from clients and proxies.
        return render_template('error.html', message=message, error=error), 429

    @server.route('/', methods=["GET"])
    @limiter.limit('50/minute')
    def home():
        return render_template('index.html', test=request.headers)

    @server.route("/<path:path>")
    @limiter.exempt
    def send_assets(path):
        return send_from_directory('../client/public', path)

    @server.route('/resources', methods=['GET'])
    def get_resources():
        """
        ### Deliver only the content and galleries that are in use on the client website

        ```sql
        SELECT clients.resource_id, headers.header_text, clients.content_id, paragraphs.paragraph_text, images.image_name, images.image_link, clients.gallery_id, gallery_info.gallery_name, gallery_info.description, galleries.index_id FROM client_resources as clients
        LEFT OUTER JOIN headers ON headers.header_id = clients.content_id
        LEFT OUTER JOIN galleries ON galleries.info_id = clients.gallery_id
        LEFT OUTER JOIN paragraphs ON paragraphs.paragraph_id = headers.paragraph_id
        LEFT OUTER JOIN images ON images.image_id = headers.image_id OR images.image_id = galleries.image_id
        LEFT OUTER JOIN gallery_info ON galleries.info_id = gallery_info.gallery_id
        ORDER BY clients.content_id, clients.gallery_id, galleries.index_id, clients.resource_id;
        ```
        """
        resources = Client_Resources.query.with_entities(
            Client_Resources.resource_id, Client_Resources.content_id,
            Header.header_text, Paragraph.paragraph_text, Image.image_name,
            Image.image_link, Client_Resources.gallery_id,
            Gallery_Info.gallery_name, Gallery_Info.description,
            Galleries.index_id).outerjoin(
                Header,
                Header.header_id == Client_Resources.content_id).outerjoin(
                    Galleries, Galleries.info_id == Client_Resources.gallery_id
                ).outerjoin(
                    Paragraph,
                    Paragraph.paragraph_id == Header.paragraph_id).outerjoin(
                        Image,
                        ((Image.image_id == Header.image_id) |
                         (Image.image_id == Galleries.image_id))).outerjoin(
                             Gallery_Info, Gallery_Info.gallery_id ==
                             Galleries.info_id).order_by(
                                 Client_Resources.content_id,
                                 Client_Resources.gallery_id,
                                 Galleries.index_id,
                                 Client_Resources.resource_id)

        # Run the query once; indexing the Query (``resources[-1]``) re-ran it
        # per gallery row and is rejected by SQLAlchemy 2.0 for negative indices.
        rows = resources.all()

        all_galleries = []
        all_content = []

        marker = None        # gallery_id of the gallery being assembled
        gallery_json = None  # gallery being assembled, appended once complete

        for row in rows:
            if row.content_id is not None:
                # The outer joins yield None (the attribute always exists), so
                # test the values; hasattr() never triggered the fallback.
                if row.image_name is not None:
                    image_name = row.image_name
                else:
                    image_name = 'Placeholder'
                if row.image_link is not None:
                    image_link = row.image_link
                else:
                    image_link = url_for(
                        'static',
                        filename='assets/icons/image-icon.inkscape.png')
                all_content.append({
                    "header_text": row.header_text,
                    "paragraph_text": row.paragraph_text,
                    "image_name": image_name,
                    "image_link": image_link
                })

            if row.gallery_id is not None:
                # Consecutive rows with the same gallery_id form one gallery
                # (relies on the ORDER BY above); a new id closes the previous one.
                if gallery_json is None or row.gallery_id != marker:
                    if gallery_json is not None:
                        all_galleries.append(gallery_json)
                    gallery_json = {
                        "gallery_name": row.gallery_name,
                        "description": row.description,
                        "images": []
                    }
                    marker = row.gallery_id

                gallery_json['images'].append({
                    "alt": row.image_name,
                    "src": row.image_link
                })

        # Flush the last gallery here: it is not necessarily on the final row,
        # since content rows may sort after gallery rows (NULL ordering is
        # database-dependent).
        if gallery_json is not None:
            all_galleries.append(gallery_json)

        all_stuff = {"content": all_content, "galleries": all_galleries}

        return jsonify(all_stuff)

    return server
def create_app():
    """Application factory: build and configure the Flask app.

    Configuration comes from ``application.config.<Class>``, where the class
    is selected by the ``FLASK_ENV`` environment variable (``production`` by
    default). It is then overridden by ``application.cfg`` in the instance
    folder when that file exists. The factory also sets up:

    * a size-rotated log file;
    * the database and migrations;
    * login handling and CSRF protection;
    * the view blueprints and template filters;
    * no-cache response headers.

    Returns:
        The configured ``Flask`` application.

    Raises:
        KeyError: if ``FLASK_ENV`` is not ``production``, ``development``
            or ``testing``.
    """
    # Application settings
    app = Flask(__name__, instance_relative_config=True)

    configs = {
        'production': 'ProductionConfig',
        'development': 'DevelopmentConfig',
        'testing': 'TestingConfig'
    }
    flask_env = os.environ.get('FLASK_ENV', default='production')
    try:
        config_class = configs[flask_env]
    except KeyError:
        # Same exception type as before, but with an actionable message.
        raise KeyError('Unknown FLASK_ENV {!r}; expected one of: {}'.format(
            flask_env, ', '.join(sorted(configs)))) from None
    app.config.from_object('application.config.{}'.format(config_class))
    app.config.from_pyfile('application.cfg', silent=True)
    os.makedirs(app.instance_path, exist_ok=True)

    # Logger: size-rotated file in LOG_DIR (3000 bytes per file, 5 backups).
    os.makedirs(app.config['LOG_DIR'], exist_ok=True)
    log_file = os.path.join(app.config['LOG_DIR'], app.config['LOG_FILE_NAME'])
    # RotatingFileHandler forces append mode whenever maxBytes > 0, so 'a' is
    # the mode actually used. An explicit UTF-8 encoding keeps non-ASCII log
    # text independent of the platform's locale encoding.
    handler = logging.handlers.RotatingFileHandler(log_file, 'a',
                                                   maxBytes=3000,
                                                   backupCount=5,
                                                   encoding='utf-8')
    handler.setFormatter(logging.Formatter(app.config['LOG_FORMAT']))
    app.logger.addHandler(handler)

    # Database
    db.init_app(app)

    # Database migrations; migration scripts live in MIGRATIONS_DIR.
    Migrate(app, db, directory=app.config['MIGRATIONS_DIR'])

    # Login session management: anonymous users hitting a protected view are
    # redirected to auth_view.login with a flashed 'danger' message.
    login_manager = LoginManager()
    login_manager.init_app(app)
    login_manager.login_view = 'auth_view.login'
    login_manager.login_message = 'ログインしてください'
    login_manager.login_message_category = 'danger'

    @login_manager.user_loader
    def load_user(user_id):
        """Return the User whose id equals *user_id*, or None if absent."""
        return User.query.filter(User.id == user_id).first()

    # CSRF protection
    csrf = CSRFProtect()
    csrf.init_app(app)

    # Register view blueprints
    app.register_blueprint(auth_view)
    app.register_blueprint(profile_view)
    app.register_blueprint(blog_view)
    app.register_blueprint(post_view)
    app.register_blueprint(home_view)
    app.register_blueprint(bookmark_view)

    @app.template_filter('base64')
    def base64_filter(val_bin):
        """Encode a bytes value as a Base64 ASCII string."""
        return base64.b64encode(val_bin).decode("ascii")

    @app.template_filter('is_bookmarked')
    def is_bookmarked_filter(post, bookmarks):
        """Return True if any bookmark in *bookmarks* points at *post*."""
        return any(post.id == bookmark.bookmark_post_id
                   for bookmark in bookmarks)

    @app.template_filter('format_date')
    def format_date_filter(value, format_='%b. %d %Y'):
        """Format a date/datetime with strftime; a missing value renders ''."""
        if value is None:
            return ''
        return value.strftime(format_)

    @app.after_request
    def add_header(r):
        """Forbid browsers and proxies from caching any response."""
        # Cache-Control is set exactly once: overwriting it afterwards with
        # 'public, max-age=0' would let shared caches store responses and
        # contradict the Pragma/Expires headers below.
        r.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
        r.headers["Pragma"] = "no-cache"
        r.headers["Expires"] = "0"
        return r

    return app
Beispiel #17
0
# Settings block for the active environment, selected by the analyzer_env
# environment variable; looked up once instead of on every line below.
_ENV_SETTINGS = json_settings[environ["analyzer_env"]]

APP.config['MONGODB_SETTINGS'] = _ENV_SETTINGS["web_mongo"]
APP.config['SESSION_COOKIE_SAMESITE'] = "Lax"
QUEUE = QBQueue("analyzer", _ENV_SETTINGS["redis_settings"])
ANALYZER_TIMEOUT = _ENV_SETTINGS["analyzer_timeout"]
FUNCTION_TIMEOUT = _ENV_SETTINGS["function_timeout"]
MALWARE_FOLDER = _ENV_SETTINGS["malware_folder"]

MONGO_DB = MongoEngine()
MONGO_DB.init_app(APP)
BCRYPT = Bcrypt(APP)
LOGIN_MANAGER = LoginManager()
# init_app is the supported API; setup_app was only a deprecated alias for it.
LOGIN_MANAGER.init_app(APP)
CSRF = CSRFProtect()
CSRF.init_app(APP)
Markdown(APP)


class Namespace:
    '''
    Attribute-style container for a set of switches: every entry of the
    mapping (or iterable of key/value pairs) given to the constructor
    becomes an attribute of the instance.
    '''

    def __init__(self, kwargs):
        # vars(self) is the instance __dict__; bulk-update it with the entries.
        instance_attrs = vars(self)
        instance_attrs.update(kwargs)


def convert_size(_size):
    '''
    Convert a size to a unit.

    NOTE(review): the implementation is missing in this file. As written, the
    function consists only of this docstring, so it returns None for every
    input. TODO: restore the body.
    '''
Beispiel #18
0
    def run(self):
        """Build the reducer's Flask web app, register its routes and serve it.

        The app gets CSRF protection and a random per-process ``SECRET_KEY``.
        Every route handler below is a closure over ``self`` (the reducer).
        When ``self.certificate`` is set, the app is served over TLS on
        0.0.0.0:8090.

        NOTE(review): when ``self.certificate`` is falsy the app is built but
        ``app.run`` is never called; confirm whether that is intended.
        """
        import os

        app = Flask(__name__)
        app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
        # Random per process: sessions and CSRF tokens do not survive a restart.
        app.config['SECRET_KEY'] = os.urandom(32)
        csrf = CSRFProtect()
        csrf.init_app(app)

        # Guards reads/writes of compute packages in UPLOAD_FOLDER across
        # concurrent requests. It must be shared by all requests, so it is
        # created once here; a lock created inside a handler protects nothing.
        package_lock = Lock()

        @app.route('/')
        def index():
            """Show the main page, or setup.html while the reducer is not ready."""
            client = self.name
            state = ReducerStateToString(self.control.state())
            logs = None
            refresh = True
            if self.current_compute_context in (None, ''):
                return render_template(
                    'setup.html',
                    client=client,
                    state=state,
                    logs=logs,
                    refresh=False,
                    message=
                    'Warning. No compute context is set. please set one with <a href="/context">/context</a>'
                )

            if self.control.state() == ReducerState.setup:
                return render_template(
                    'setup.html',
                    client=client,
                    state=state,
                    logs=logs,
                    refresh=refresh,
                    message=
                    'Warning. Reducer is not base-configured. please do so with config file.'
                )

            return render_template('index.html',
                                   client=client,
                                   state=state,
                                   logs=logs,
                                   refresh=refresh)

        # http://localhost:8090/add?name=combiner&address=combiner&port=12080&token=e9a3cb4c5eaff546eec33ff68a7fbe232b68a192
        @app.route('/add')
        def add():
            """ Add a combiner to the network. """
            if self.control.state() == ReducerState.setup:
                return jsonify({'status': 'retry'})

            # TODO check for get variables
            name = request.args.get('name', None)
            address = request.args.get('address', None)
            port = request.args.get('port', None)
            # token = request.args.get('token')
            # TODO do validation

            # Validate before converting: str(None) would turn a missing
            # address into the string 'None' and slip past this check.
            if port is None or address is None or name is None:
                return "Please specify correct parameters."
            address = str(address)

            # Try to retrieve combiner from db
            combiner = self.control.network.get_combiner(name)
            if not combiner:
                # Create a new combiner with a keypair issued for its address.
                certificate, key = self.certificate_manager.get_or_create(
                    address).get_keypair_raw()

                # TODO append and redirect to index.
                import copy
                combiner = CombinerInterface(self, name, address, port,
                                             copy.deepcopy(certificate),
                                             copy.deepcopy(key),
                                             request.remote_addr)
                self.control.network.add_combiner(combiner)

            combiner = self.control.network.get_combiner(name)

            ret = {
                'status': 'added',
                'certificate': combiner['certificate'],
                'key': combiner['key'],
                'storage': self.control.statestore.get_storage_backend(),
                'statestore': self.control.statestore.get_config(),
            }

            return jsonify(ret)

        @app.route('/history', methods=['GET', 'POST'])
        def history():
            """Commit an uploaded seed model (POST) or show model history (GET)."""
            if request.method == 'POST':
                # upload seed file
                uploaded_seed = request.files['seed']
                if uploaded_seed:
                    from io import BytesIO
                    buffer = BytesIO()
                    uploaded_seed.seek(0)
                    buffer.write(uploaded_seed.read())
                    helper = self.control.get_helper()
                    model = helper.load_model_from_BytesIO(buffer.getbuffer())
                    self.control.commit(uploaded_seed.filename, model)
            else:
                h_latest_model_id = self.control.get_latest_model()
                model_info = self.control.get_model_info()
                return render_template('index.html',
                                       h_latest_model_id=h_latest_model_id,
                                       seed=True,
                                       model_info=model_info)

            seed = True
            return redirect(url_for('history', seed=seed))

        @app.route('/delete_model_trail', methods=['GET', 'POST'])
        def delete_model_trail():
            """On POST, drop stored models and bucket objects; then go to history."""
            if request.method == 'POST':
                from fedn.common.tracer.mongotracer import MongoTracer
                statestore_config = self.control.statestore.get_config()
                self.tracer = MongoTracer(statestore_config['mongo_config'],
                                          statestore_config['network_id'])
                try:
                    self.control.drop_models()
                except Exception as err:
                    # Best effort: still clear the bucket below, but do not
                    # hide the failure.
                    print("Failed to drop models: {}".format(err), flush=True)

                # drop objects in minio
                self.control.delete_bucket_objects()
                return redirect(url_for('history'))
            seed = True
            return redirect(url_for('history', seed=seed))

        @app.route('/drop_control', methods=['GET', 'POST'])
        def drop_control():
            """On POST, drop the stored control state; always redirect to start."""
            if request.method == 'POST':
                self.control.statestore.drop_control()
            return redirect(url_for('start'))

        # http://localhost:8090/start?rounds=4&model_id=879fa112-c861-4cb1-a25d-775153e5b548
        @app.route('/start', methods=['GET', 'POST'])
        def start():
            """Show the run form (GET) or send an execution plan (POST)."""
            client = self.name
            state = ReducerStateToString(self.control.state())
            logs = None
            refresh = True
            try:
                self.current_compute_context = self.control.get_compute_context(
                )
            except Exception:
                # Treat any lookup failure as "no compute context set".
                self.current_compute_context = None

            if self.current_compute_context in (None, ''):
                return render_template(
                    'setup.html',
                    client=client,
                    state=state,
                    logs=logs,
                    refresh=False,
                    message=
                    'No compute context is set. Please set one here <a href="/context">/context</a>'
                )

            if self.control.state() == ReducerState.setup:
                return render_template(
                    'setup.html',
                    client=client,
                    state=state,
                    logs=logs,
                    refresh=refresh,
                    message=
                    'Warning. Reducer is not base-configured. please do so with config file.'
                )

            if request.method == 'POST':

                timeout = float(request.form.get('timeout'))
                rounds = int(request.form.get('rounds', 1))
                task = (request.form.get('task', ''))
                clients_required = request.form.get('clients_required', 1)
                clients_requested = request.form.get('clients_requested', 8)

                # TODO: Enable in UI
                validate = request.form.get('validate', True)

                latest_model_id = self.control.get_latest_model()

                config = {
                    'round_timeout': timeout,
                    'model_id': latest_model_id,
                    'rounds': rounds,
                    'clients_required': clients_required,
                    'clients_requested': clients_requested,
                    'task': task,
                    'validate': validate
                }

                self.control.instruct(config)
                return redirect(
                    url_for('index', message="Sent execution plan."))

            # Select rounds UI
            rounds = range(1, 200)
            latest_model_id = self.control.get_latest_model()
            return render_template(
                'index.html',
                round_options=rounds,
                latest_model_id=latest_model_id,
                compute_package=self.current_compute_context,
                helper=self.control.statestore.get_framework(),
                validate=True)

        @app.route('/assign')
        def assign():
            """Handle client assignment requests. """

            if self.control.state() == ReducerState.setup:
                return jsonify({'status': 'retry'})

            name = request.args.get('name', None)
            combiner_preferred = request.args.get('combiner', None)

            if combiner_preferred:
                combiner = self.control.find(combiner_preferred)
            else:
                combiner = self.control.find_available_combiner()

            client = {
                'name': name,
                'combiner_preferred': combiner_preferred,
                'ip': request.remote_addr,
                'status': 'available'
            }
            self.control.network.add_client(client)

            if combiner:
                import base64
                cert_b64 = base64.b64encode(combiner.certificate)
                response = {
                    'status': 'assigned',
                    'host': combiner.address,
                    'port': combiner.port,
                    # Base64 output is pure ASCII, so decoding yields the text.
                    'certificate': cert_b64.decode('utf-8'),
                    'model_type': self.control.statestore.get_framework()
                }

                return jsonify(response)

            return jsonify({'status': 'retry'})

        @app.route('/infer')
        def infer():
            """Set the control's model id; returns an empty body."""
            if self.control.state() == ReducerState.setup:
                return "Error, not configured"
            result = ""
            try:
                self.control.set_model_id()
            except fedn.exceptions.ModelError:
                print("Failed to seed control.")

            return result

        def combiner_stats():
            """Return report() output for every combiner that answers.

            Combiners whose report() raises are skipped (and logged).
            """
            combiner_info = []
            for combiner in self.control.network.get_combiners():
                try:
                    combiner_info.append(combiner.report())
                except Exception as err:
                    print("Failed to get combiner report: {}".format(err),
                          flush=True)
            return combiner_info

        def create_map():
            """Return plotly JSON for a world map of combiners and clients.

            IPs are resolved with the bundled GeoLite2 city database; IPs not
            found are printed and skipped. Combiner positions get a random
            offset (radius up to 1 degree) so co-located ones do not overlap.
            """
            cities_dict = {
                'city': [],
                'lat': [],
                'lon': [],
                'country': [],
                'name': [],
                'role': [],
                'size': []
            }

            from fedn import get_data
            dbpath = get_data('geolite2/GeoLite2-City.mmdb')

            with geoip2.database.Reader(dbpath) as reader:
                for combiner in self.control.statestore.list_combiners():
                    try:
                        response = reader.city(combiner['ip'])
                        cities_dict['city'].append(response.city.name)

                        r = 1.0  # Roughly 100km
                        w = r * math.sqrt(numpy.random.random())
                        t = 2.0 * math.pi * numpy.random.random()
                        x = w * math.cos(t)
                        y = w * math.sin(t)
                        # Store numbers (as for clients), not strings.
                        cities_dict['lat'].append(
                            float(response.location.latitude) + x)
                        cities_dict['lon'].append(
                            float(response.location.longitude) + y)

                        cities_dict['country'].append(
                            response.country.iso_code)

                        cities_dict['name'].append(combiner['name'])
                        cities_dict['role'].append('Combiner')
                        cities_dict['size'].append(10)

                    except geoip2.errors.AddressNotFoundError as err:
                        print(err)

                for client in self.control.statestore.list_clients():
                    try:
                        response = reader.city(client['ip'])
                        cities_dict['city'].append(response.city.name)
                        cities_dict['lat'].append(response.location.latitude)
                        cities_dict['lon'].append(response.location.longitude)
                        cities_dict['country'].append(
                            response.country.iso_code)

                        cities_dict['name'].append(client['name'])
                        cities_dict['role'].append('Client')
                        # TODO: Optionally relate to data size
                        cities_dict['size'].append(6)

                    except geoip2.errors.AddressNotFoundError as err:
                        print(err)

            config = self.control.statestore.get_config()

            cities_df = pd.DataFrame(cities_dict)

            fig = px.scatter_geo(cities_df,
                                 lon="lon",
                                 lat="lat",
                                 projection="natural earth",
                                 color="role",
                                 size="size",
                                 hover_name="city",
                                 hover_data={
                                     "city": False,
                                     "lon": False,
                                     "lat": False,
                                     'size': False,
                                     'name': True,
                                     'role': True
                                 })

            fig.update_geos(fitbounds="locations", showcountries=True)
            fig.update_layout(
                title="FEDn network: {}".format(config['network_id']))

            fig = json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
            return fig

        @app.route('/dashboard')
        def dashboard():
            """Render the training dashboard plots."""
            from fedn.clients.reducer.plots import Plot
            plot = Plot(self.control.statestore)
            box_plot = plot.create_box_plot()
            table_plot = plot.create_table_plot()
            # timeline_plot = plot.create_timeline_plot()
            timeline_plot = None
            clients_plot = plot.create_client_plot()
            return render_template(
                'index.html',
                show_plot=True,
                box_plot=box_plot,
                table_plot=table_plot,
                timeline_plot=timeline_plot,
                clients_plot=clients_plot,
            )

        @app.route('/network')
        def network():
            """Render network plots, the geo map and per-combiner reports."""
            from fedn.clients.reducer.plots import Plot
            plot = Plot(self.control.statestore)
            round_time_plot = plot.create_round_plot()
            mem_cpu_plot = plot.create_cpu_plot()
            combiners_plot = plot.create_combiner_plot()
            map_plot = create_map()
            combiner_info = combiner_stats()
            return render_template('index.html',
                                   map_plot=map_plot,
                                   network_plot=True,
                                   round_time_plot=round_time_plot,
                                   mem_cpu_plot=mem_cpu_plot,
                                   combiners_plot=combiners_plot,
                                   combiner_info=combiner_info)

        @app.route('/context', methods=['GET', 'POST'])
        @csrf.exempt  # TODO fix csrf token to form posting in package.py
        def context():
            """Upload a compute package (POST) or download one (GET).

            GET serves ``?name=`` (default: the current compute context). It
            uses the copy in UPLOAD_FOLDER and fetches and caches the package
            from the control when it is not present locally.
            """
            # if self.control.state() != ReducerState.setup or self.control.state() != ReducerState.idle:
            #    return "Error, Context already assigned!"
            reset = request.args.get(
                'reset',
                None)  # if reset is not empty then allow context re-set
            if reset:
                return render_template('context.html')

            if request.method == 'POST':

                if 'file' not in request.files:
                    flash('No file part')
                    return redirect(request.url)

                file = request.files['file']
                # if user does not select file, browser also
                # submit an empty part without filename
                if file.filename == '':
                    flash('No selected file')
                    return redirect(request.url)

                if file and allowed_file(file.filename):
                    # Refuse before writing anything, so an ongoing run never
                    # sees its package overwritten on disk.
                    if self.control.state() in (ReducerState.instructing,
                                                ReducerState.monitoring):
                        return "Not allowed to change context while execution is ongoing."

                    filename = secure_filename(file.filename)
                    file_path = os.path.join(app.config['UPLOAD_FOLDER'],
                                             filename)
                    with package_lock:
                        file.save(file_path)

                    self.control.set_compute_context(filename, file_path)
                    return redirect(url_for('start'))
                # NOTE(review): a disallowed file type falls through to the
                # download path below; confirm this is intended.

            from flask import send_from_directory
            name = request.args.get('name', '')

            if name == '':
                name = self.control.get_compute_context()
                if name is None or name == '':
                    return render_template('context.html')

            upload_folder = app.config['UPLOAD_FOLDER']
            # Serialise with uploads and with other requests that may be
            # fetching the package from storage and writing it to disk.
            with package_lock:
                try:
                    return send_from_directory(upload_folder,
                                               name,
                                               as_attachment=True)
                except Exception:
                    # Not available locally: fetch it, cache it, then serve it.
                    # The request-supplied name is sanitised before being used
                    # as a local path to prevent writes outside upload_folder.
                    data = self.control.get_compute_package(name)
                    local_name = secure_filename(name)
                    file_path = os.path.join(upload_folder, local_name)
                    with open(file_path, 'wb') as fh:
                        fh.write(data)
                    return send_from_directory(upload_folder,
                                               local_name,
                                               as_attachment=True)

        if self.certificate:
            print("trying to connect with certs {} and key {}".format(
                str(self.certificate.cert_path),
                str(self.certificate.key_path)),
                  flush=True)
            app.run(host="0.0.0.0",
                    port=8090,
                    ssl_context=(str(self.certificate.cert_path),
                                 str(self.certificate.key_path)))
Beispiel #19
0
def init_csrf(app):
    """Enable Flask-WTF CSRF protection on *app* and return the extension.

    Intended for apps whose pages are rendered by Flask itself (frontend and
    backend not separated), so forms can embed the CSRF token.

    Tokens are signed with ``SECRET_KEY``. A key already present in
    ``app.config`` is kept. Otherwise the ``SECRET_KEY`` environment variable
    is used, and only as a last resort the built-in development default.

    Returns:
        The ``CSRFProtect`` instance bound to *app*.
    """
    import os

    if not app.config.get('SECRET_KEY'):
        # NOTE(review): the literal fallback is visible in source control and
        # must not be relied on in production; configure SECRET_KEY instead.
        app.config['SECRET_KEY'] = (os.environ.get('SECRET_KEY')
                                    or 'you never guess')
    # CSRFProtect(app) already calls init_app(app); calling init_app again
    # would register the CSRF before_request check a second time.
    return CSRFProtect(app)
from flask import Flask
from flask_wtf.csrf import CSRFProtect

# Flask app with CSRF protection enabled globally; the decorated functions
# below exercise how `csrf.exempt` usage is detected.
app = Flask(__name__)
# Fixed key for this fixture only; it signs the CSRF tokens.
app.config['SECRET_KEY'] = 'top-secret!'
csrf = CSRFProtect()
csrf.init_app(app)  # Compliant


@app.route('/csrftest1/', methods=['POST'])
@csrf.exempt  # Noncompliant {{Make sure disabling CSRF protection is safe here.}}
#     ^^^^^^
def csrftestpost():
    """POST endpoint with CSRF checking turned off via `csrf.exempt`; the
    marker comment above records that the decorator is expected to be flagged."""
    pass


# Corner cases for test coverage: decorators that look like the CSRF API but
# are not `csrf.exempt`, so they must not be reported.
# NOTE(review): `thisDoesntExist` is presumably not a real CSRFProtect
# attribute; if so, executing this module raises AttributeError. The file
# appears intended for static analysis only.
@csrf.thisDoesntExist  # Compliant
def ok1():
    pass


# `exempt` accessed on an object other than the CSRFProtect instance.
# NOTE(review): `unrelatedDecorator` is not defined in the visible code.
@unrelatedDecorator.exempt
def ok2():
    pass
Beispiel #21
0
from flask_wtf.csrf import CSRFProtect
import os
# Extension singletons are created unbound here and attached to the app below.
# Database
db = SQLAlchemy()

# CSRF protector
csrf = CSRFProtect()

# Login Manager for Users
login_manager = LoginManager()

# Creating flask app for use throughout
application = Flask(__name__)
# Config object named by the APP_SETTINGS environment variable (a string such
# as a dotted import path). NOTE(review): when APP_SETTINGS is unset this
# passes None to from_object, which appears to load nothing without raising;
# confirm that is acceptable.
application.config.from_object(os.environ.get('APP_SETTINGS'))
application.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = True
csrf.init_app(application)
db.init_app(application)

# Setting up login manager for app; anonymous users are sent to "user.login"
login_manager.init_app(application)
login_manager.login_message = "Login is required to enter dashboard"
login_manager.login_view = "user.login"

# Sets up Bootstrap to allow for wtf forms
Bootstrap(application)

# Registering blueprints for app.
# These imports run after `application` exists. Presumably this lets the
# blueprint modules import from this package without a circular import, so
# keep this ordering.
from .users import user as user_blueprint
application.register_blueprint(user_blueprint)

# NOTE(review): health_blueprint is imported but not registered in the code
# visible here; the register_blueprint call may be missing — verify.
from .health import health as health_blueprint
Beispiel #22
0
    def run(self):
        app = Flask(__name__)
        app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
        csrf = CSRFProtect()
        import os
        SECRET_KEY = os.urandom(32)
        app.config['SECRET_KEY'] = SECRET_KEY

        csrf.init_app(app)

        c = pymongo.MongoClient()
        mc = pymongo.MongoClient(os.environ['FEDN_MONGO_HOST'],
                                 int(os.environ['FEDN_MONGO_PORT']),
                                 username=os.environ['FEDN_MONGO_USER'],
                                 password=os.environ['FEDN_MONGO_PASSWORD'])
        mdb = mc[os.environ['ALLIANCE_UID']]
        alliance = mdb["status"]

        @app.route('/')
        def index():
            # logs_fancy = str()
            # for log in self.logs:
            #    logs_fancy += "<p>" + log + "</p>\n"
            client = self.name
            state = ReducerStateToString(self.control.state())
            logs = None
            refresh = True
            if self.control.state() == ReducerState.setup:
                return render_template(
                    'setup.html',
                    client=client,
                    state=state,
                    logs=logs,
                    refresh=refresh,
                    dashboardhost=os.environ["FEDN_DASHBOARD_HOST"],
                    dashboardport=os.environ["FEDN_DASHBOARD_PORT"])

            return render_template(
                'index.html',
                client=client,
                state=state,
                logs=logs,
                refresh=refresh,
                dashboardhost=os.environ["FEDN_DASHBOARD_HOST"],
                dashboardport=os.environ["FEDN_DASHBOARD_PORT"])

        # http://localhost:8090/add?name=combiner&address=combiner&port=12080&token=e9a3cb4c5eaff546eec33ff68a7fbe232b68a192
        @app.route('/add')
        def add():
            if self.control.state() == ReducerState.setup:
                return jsonify({'status': 'retry'})
            # TODO check for get variables
            name = request.args.get('name', None)
            address = request.args.get('address', None)
            port = request.args.get('port', None)
            # token = request.args.get('token')
            # TODO do validation

            if port is None or address is None or name is None:
                return "Please specify correct parameters."

            certificate, key = self.certificate_manager.get_or_create(
                address).get_keypair_raw()
            import base64
            cert_b64 = base64.b64encode(certificate)
            key_b64 = base64.b64encode(key)

            # TODO append and redirect to index.
            import copy
            combiner = CombinerInterface(self, name, address, port,
                                         copy.deepcopy(certificate),
                                         copy.deepcopy(key))
            self.control.add(combiner)

            ret = {
                'status': 'added',
                'certificate': str(cert_b64).split('\'')[1],
                'key': str(key_b64).split('\'')[1]
            }  # TODO remove ugly string hack
            return jsonify(ret)

        @app.route('/seed', methods=['GET', 'POST'])
        def seed():
            if request.method == 'POST':
                # upload seed file
                uploaded_seed = request.files['seed']
                if uploaded_seed:
                    self.control.commit(uploaded_seed.filename, uploaded_seed)
            else:
                h_latest_model_id = self.control.get_latest_model()
                model_info = self.control.get_model_info()
                return render_template('index.html',
                                       h_latest_model_id=h_latest_model_id,
                                       seed=True,
                                       model_info=model_info)

            seed = True
            return redirect(url_for('seed', seed=seed))

        # http://localhost:8090/start?rounds=4&model_id=879fa112-c861-4cb1-a25d-775153e5b548
        @app.route('/start', methods=['GET', 'POST'])
        def start():
            if self.control.state() == ReducerState.setup:
                return "Error, not configured"

            if request.method == 'POST':
                timeout = request.form.get('timeout', 180)
                rounds = int(request.form.get('rounds', 1))

                task = (request.form.get('task', ''))
                active_clients = request.form.get('active_clients', 2)
                clients_required = request.form.get('clients_required', 2)
                clients_requested = request.form.get('clients_requested', 8)

                latest_model_id = self.control.get_latest_model()
                config = {
                    'round_timeout': timeout,
                    'model_id': latest_model_id,
                    'rounds': rounds,
                    'active_clients': active_clients,
                    'clients_required': clients_required,
                    'clients_requested': clients_requested,
                    'task': task
                }

                self.control.instruct(config)
                return redirect(
                    url_for('index', message="Sent execution plan."))

            else:
                # Select rounds UI
                rounds = range(1, 100)
                latest_model_id = self.control.get_latest_model()
                return render_template('index.html',
                                       round_options=rounds,
                                       latest_model_id=latest_model_id)

            client = self.name
            state = ReducerStateToString(self.control.state())
            logs = None
            refresh = False
            return render_template('index.html',
                                   client=client,
                                   state=state,
                                   logs=logs,
                                   refresh=refresh)

        @app.route('/assign')
        def assign():
            """Assign a requesting client to a combiner.

            Query parameters:
                combiner: optional name of a preferred combiner; when omitted an
                    available one is picked with
                    ``self.control.find_available_combiner()``.
                name: client name (accepted but not used by this handler).

            Returns JSON ``{'status': 'retry'}`` while the reducer is in setup or
            when no combiner was found, otherwise ``{'status': 'assigned',
            'host', 'port', 'certificate'}`` with the combiner certificate
            base64-encoded as text.
            """
            if self.control.state() == ReducerState.setup:
                return jsonify({'status': 'retry'})

            combiner_preferred = request.args.get('combiner', None)
            if combiner_preferred:
                combiner = self.control.find(combiner_preferred)
            else:
                combiner = self.control.find_available_combiner()

            if not combiner:
                # No combiner can take the client yet; the client polls again.
                return jsonify({'status': 'retry'})

            import base64
            # b64encode returns bytes; decode them instead of slicing the
            # "b'...'" repr of the bytes object.
            cert_b64 = base64.b64encode(combiner.certificate).decode('ascii')
            return jsonify({
                'status': 'assigned',
                'host': combiner.address,
                'port': combiner.port,
                'certificate': cert_b64,
            })

        @app.route('/infer')
        def infer():
            """Ask the control layer to set its model id.

            While the reducer is still in setup an error string is returned.
            Otherwise ``self.control.set_model_id()`` is called; a
            ``fedn.exceptions.ModelError`` is reported on stdout and ignored.
            The response body is always empty in that case.
            """
            if self.control.state() == ReducerState.setup:
                return "Error, not configured"

            try:
                self.control.set_model_id()
            except fedn.exceptions.ModelError:
                print("Failed to seed control.")
            return ""

        # plot metrics from DB
        def _scalar_metrics(metrics):
            """Return the names of the scalar-valued metrics of a MODEL_VALIDATION.

            ``metrics`` is a stored message whose ``'data'`` field is a JSON
            string; that object's own ``'data'`` field is another JSON string
            mapping metric names to values. A metric is scalar when its value
            can be converted with ``float()``. Names keep their original order.
            """
            outer = json.loads(metrics['data'])
            values = json.loads(outer['data'])

            scalar_names = []
            for name, value in values.items():
                try:
                    float(value)
                except (TypeError, ValueError, OverflowError):
                    # Lists, dicts, None, non-numeric strings... are not scalar.
                    # (A bare except here used to swallow KeyboardInterrupt too.)
                    continue
                scalar_names.append(name)
            return scalar_names

        @app.route('/plot')
        def plot():
            """Render the index page with the box plot embedded."""
            return render_template('index.html',
                                   show_plot=True,
                                   plot=create_plot('box'))

        def create_plot(feature):
            """Return the Plotly JSON for the requested plot ``feature``.

            Supported features are 'table', 'timeline', 'ml' and 'box'; any
            other value yields the plain string 'No plot!'.
            """
            builders = (
                ('table', create_table_plot),
                ('timeline', create_timeline_plot),
                ('ml', create_ml_plot),
                ('box', create_box_plot),
            )
            for name, builder in builders:
                if feature == name:
                    return builder()
            return 'No plot!'

        @app.route('/plot_type', methods=['GET', 'POST'])
        def change_features():
            """Return the Plotly JSON of the plot named by the ``selected`` arg."""
            return create_plot(request.args['selected'])

        def create_table_plot():
            """Build a Plotly table of per-model mean values of each scalar metric.

            Scalar metrics are taken from the first stored MODEL_VALIDATION.
            There is one row per model id, in reverse order of first appearance.
            Each metric column holds the mean over that model's validations.
            A model that never reported a metric gets an empty cell.
            Returns the figure serialized with ``PlotlyJSONEncoder``.
            """
            metrics = alliance.find_one({'type': 'MODEL_VALIDATION'})
            if metrics is None:
                fig = go.Figure(data=[])
                fig.update_layout(
                    title_text='No data currently available for mean metrics')
                return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)

            valid_metrics = _scalar_metrics(metrics)
            if not valid_metrics:
                fig = go.Figure(data=[])
                fig.update_layout(title_text='No scalar metrics found')
                return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)

            # Query once and group the decoded metric dicts by model id. The
            # dict preserves insertion order, so all columns share one row
            # order. Previously each metric re-queried the DB, and only the
            # last metric's model order was used for every column.
            per_model = {}
            for post in alliance.find({'type': 'MODEL_VALIDATION'}):
                e = json.loads(post['data'])
                per_model.setdefault(e['modelId'], []).append(
                    json.loads(e['data']))

            models = list(per_model)
            models.reverse()
            values = [models]
            for metric in valid_metrics:
                column = []
                for model in models:
                    samples = [
                        float(d[metric]) for d in per_model[model]
                        if metric in d
                    ]
                    # A missing metric used to raise KeyError from inside the
                    # except handler and break the whole table.
                    column.append(numpy.mean(samples) if samples else None)
                values.append(column)

            fig = go.Figure(data=[
                go.Table(
                    header=dict(values=['Model ID'] + valid_metrics,
                                line_color='darkslategray',
                                fill_color='lightskyblue',
                                align='left'),
                    cells=dict(
                        values=values,
                        line_color='darkslategray',
                        fill_color='lightcyan',
                        align='left'))
            ])

            fig.update_layout(title_text='Summary: mean metrics')
            return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)

        def create_timeline_plot():
            """Build a horizontal bar timeline of training and validation rounds.

            Each MODEL_UPDATE_REQUEST / MODEL_VALIDATION_REQUEST is paired with
            the MODEL_UPDATE / MODEL_VALIDATION from the same sender that carries
            the same ``correlationId``. The bar starts at the request timestamp
            and is as long as the request-to-reply time in seconds. Requests
            without a reply yet are skipped. Returns the figure serialized with
            ``PlotlyJSONEncoder``.
            """

            def _parse_timestamp(value):
                # Timestamps are stored as '%Y-%m-%d %H:%M:%S.%f'. Also accept
                # the form without '.%f' (what str(datetime) yields when the
                # microseconds are 0) — assumption about the producer, verify.
                try:
                    return datetime.strptime(value, '%Y-%m-%d %H:%M:%S.%f')
                except ValueError:
                    return datetime.strptime(value, '%Y-%m-%d %H:%M:%S')

            def _durations(request_type, reply_type):
                # Return (durations_in_s, sender_names, start_timestamps) for
                # every request of ``request_type`` that has a reply.
                x, y, base = [], [], []
                for p in alliance.find({'type': request_type}):
                    e = json.loads(p['data'])
                    cid = e['correlationId']
                    reply = None
                    for cc in alliance.find({
                            'sender': p['sender'],
                            'type': reply_type
                    }):
                        if json.loads(cc['data'])['correlationId'] == cid:
                            reply = cc  # as before, the last match wins
                    if reply is None:
                        # Reply not stored yet: this used to raise NameError or
                        # silently reuse the previous request's reply.
                        continue
                    cd = json.loads(reply['data'])
                    tr = _parse_timestamp(e['timestamp'])
                    tu = _parse_timestamp(cd['timestamp'])
                    base.append(tr.timestamp())
                    x.append((tu - tr).total_seconds())
                    y.append(p['sender']['name'])
                return x, y, base

            trace_data = []
            series = (
                ('MODEL_UPDATE_REQUEST', 'MODEL_UPDATE', 'royalblue',
                 "Training"),
                ('MODEL_VALIDATION_REQUEST', 'MODEL_VALIDATION',
                 'lightskyblue', "Validation"),
            )
            for request_type, reply_type, color, label in series:
                x, y, base = _durations(request_type, reply_type)
                trace_data.append(
                    go.Bar(
                        x=x,
                        y=y,
                        orientation='h',
                        base=base,
                        marker=dict(color=color),
                        name=label,
                    ))

            layout = go.Layout(
                barmode='stack',
                showlegend=True,
            )

            fig = go.Figure(data=trace_data, layout=layout)
            fig.update_xaxes(title_text='Timestamp')
            fig.update_layout(title_text='Alliance timeline')
            return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)

        def create_ml_plot():
            """Build a per-client line plot of one scalar validation metric.

            The metric is the last float-valued entry of the first stored
            MODEL_VALIDATION. The old code used the last key whatever its type,
            and raised NameError when the dict was empty. Every client gets one
            trace: y holds the client's values in query order and x the 0-based
            index of each value (x used to be an empty list, so the traces had
            no usable x coordinates). Returns the figure serialized with
            ``PlotlyJSONEncoder``.
            """
            metrics = alliance.find_one({'type': 'MODEL_VALIDATION'})
            if metrics is None:
                fig = go.Figure(data=[])
                fig.update_layout(
                    title_text=
                    'No data currently available for Mean Absolute Error')
                return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)

            data = json.loads(json.loads(metrics['data'])['data'])
            # Only real floats count as scalar here - is this robust ?
            valid_metrics = [
                name for name, val in data.items() if isinstance(val, float)
            ]
            if not valid_metrics:
                fig = go.Figure(data=[])
                fig.update_layout(title_text='No scalar metrics found')
                return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
            metric = valid_metrics[-1]

            # Collect all values per client. Previously the first value of each
            # client was dropped, and a validation lacking the metric reset the
            # client's list.
            clients = {}
            for post in alliance.find({'type': 'MODEL_VALIDATION'}):
                e = json.loads(post['data'])
                values = json.loads(e['data'])
                if metric in values:
                    clients.setdefault(post['sender']['name'],
                                       []).append(values[metric])

            traces_data = [
                go.Scatter(x=list(range(len(ys))), y=ys, name=c)
                for c, ys in clients.items()
            ]
            fig = go.Figure(traces_data)
            fig.update_xaxes(title_text='Rounds')
            fig.update_yaxes(title_text='MAE',
                             tickvals=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
            fig.update_layout(title_text='Mean Absolute Error Plot')
            return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)

        def create_box_plot():
            """Build box plots of one validation metric per model, plus its mean.

            The metric is "accuracy" when it is scalar, otherwise the first
            scalar metric of the first stored MODEL_VALIDATION. Models with at
            least two validations get a box trace. A 'Mean' scatter trace
            connects the per-model means. Validations that lack the metric are
            ignored; they used to raise KeyError. Returns the figure serialized
            with ``PlotlyJSONEncoder``.
            """
            metrics = alliance.find_one({'type': 'MODEL_VALIDATION'})
            if metrics is None:
                fig = go.Figure(data=[])
                fig.update_layout(
                    title_text=
                    'No data currently available for metric distribution over alliance '
                    'participants')
                return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)

            valid_metrics = _scalar_metrics(metrics)
            if not valid_metrics:
                fig = go.Figure(data=[])
                fig.update_layout(title_text='No scalar metrics found')
                return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)

            # Prefer "accuracy", otherwise grab the first metric in the list.
            # TODO: Let the user choose, or plot all of them.
            if "accuracy" in valid_metrics:
                metric = "accuracy"
            else:
                metric = valid_metrics[0]

            validations = {}
            for post in alliance.find({'type': 'MODEL_VALIDATION'}):
                e = json.loads(post['data'])
                values = json.loads(e['data'])
                if metric not in values:
                    continue
                validations.setdefault(e['modelId'],
                                       []).append(float(values[metric]))

            box = go.Figure()

            x = []
            y = []
            for model_id, acc in validations.items():
                x.append(model_id)
                y.append(numpy.mean(acc))  # values are already floats
                if len(acc) >= 2:
                    box.add_trace(
                        go.Box(y=acc,
                               name=str(model_id),
                               marker_color="royalblue",
                               showlegend=False))

            box.add_trace(go.Scatter(x=x, y=y, name='Mean'))

            box.update_xaxes(title_text='Model ID')
            box.update_yaxes(tickvals=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
            box.update_layout(
                title_text='Metric distribution over alliance participants: {}'
                .format(metric))
            return json.dumps(box, cls=plotly.utils.PlotlyJSONEncoder)

        # @app.route('/seed')
        # def seed():
        #    try:
        #        result = self.inference.infer(request.args)
        #    except fedn.exceptions.ModelError:
        #        print("no model")
        #
        #    return result
        @app.route('/context', methods=['GET', 'POST'])
        @csrf.exempt  # TODO fix csrf token to form posting in package.py
        def context():
            """Upload (POST) or download (GET) the compute context package.

            POST expects a multipart ``file`` field. The upload is refused while
            the reducer is instructing or monitoring. Otherwise the file is saved
            under ``UPLOAD_FOLDER`` and becomes the current compute context.
            GET with ``?name=<file>`` downloads that file from ``UPLOAD_FOLDER``.
            Without a name, the current compute context is downloaded when one is
            set; otherwise the upload page is rendered.
            """
            # if self.control.state() != ReducerState.setup or self.control.state() != ReducerState.idle:
            #    return "Error, Context already assigned!"
            if request.method == 'POST':

                if 'file' not in request.files:
                    flash('No file part')
                    return redirect(request.url)

                file = request.files['file']
                # If the user does not select a file, the browser still
                # submits an empty part without a filename.
                if file.filename == '':
                    flash('No selected file')
                    return redirect(request.url)

                if not (file and allowed_file(file.filename)):
                    # Previously fell through to the download/template branch.
                    flash('File type not allowed')
                    return redirect(request.url)

                # Check before touching disk. Saving first could overwrite the
                # package used by the ongoing execution even though the change
                # is then refused.
                if self.control.state() in (ReducerState.instructing,
                                            ReducerState.monitoring):
                    return "Not allowed to change context while execution is ongoing."

                filename = secure_filename(file.filename)
                file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))
                # Uploading a new file always makes it the current context.
                self.current_compute_context = filename
                self.control.set_compute_context(filename)
                return "success!"

            from flask import send_from_directory
            name = request.args.get('name', '')
            if name != '':
                return send_from_directory(app.config['UPLOAD_FOLDER'],
                                           name,
                                           as_attachment=True)
            if self.current_compute_context:
                return send_from_directory(app.config['UPLOAD_FOLDER'],
                                           self.current_compute_context,
                                           as_attachment=True)

            return render_template('context.html')

        # import os, sys
        # self._original_stdout = sys.stdout
        # sys.stdout = open(os.devnull, 'w')
        if self.certificate:
            print("trying to connect with certs {} and key {}".format(
                str(self.certificate.cert_path),
                str(self.certificate.key_path)),
                  flush=True)
            app.run(host="0.0.0.0",
                    port="8090",
                    ssl_context=(str(self.certificate.cert_path),
                                 str(self.certificate.key_path)))
Beispiel #23
0
class AdminView(AdminMixin, ModelView):
    """Flask-Admin model view adding nothing beyond AdminMixin and ModelView."""


class HomeAdminView(AdminMixin, AdminIndexView):
    """Flask-Admin index view adding nothing beyond AdminMixin and AdminIndexView."""


# Flask-Admin interface mounted at the site root, using HomeAdminView as index.
# NOTE(review): Admin is constructed before the settings below are loaded;
# confirm nothing Flask-Admin reads at construction time lives in those modules.
admin = Admin(app,
              'BTL Main Site',
              url='/',
              index_view=HomeAdminView(name='Home'))
# Base settings first; keys also defined in local_settings are overwritten by
# the second from_object call.
app.config.from_object('btlstattracker.settings')
# Load environment specific settings
app.config.from_object('btlstattracker.local_settings')
# Extensions are initialised only after the configuration has been loaded.
csrf_protect.init_app(app)
bcrypt.init_app(app)

# #####################################
# ### DATABASE SETUP ##############
# ##########################

# SQLAlchemy bound to the configured app; Migrate registers Alembic migrations.
db = SQLAlchemy(app)
Migrate(app, db)

# #######################
# ### LOGIN CONFIGS
# Login management is currently disabled; the lines below are kept for reference.
# login_manager = LoginManager()
#
# login_manager.init_app(app)
# login_manager.login_view = 'user.login'
Beispiel #24
0
class Server(Flask):
    def __init__(self, mode=None):
        """Create the Flask application and run every configuration step.

        ``mode`` is forwarded to ``update_config``, which falls back to the
        configured mode when it is None. ``app`` here is the module-level
        application object (presumably the eNMS controller — not this Flask
        instance). The steps run in a fixed order because later ones use
        attributes created earlier: ``register_extensions`` creates
        ``self.auth`` and ``self.csrf``, which ``configure_authentication``
        and ``configure_rest_api`` rely on.
        """
        static_folder = str(app.path / "eNMS" / "static")
        super().__init__(__name__, static_folder=static_folder)
        self.update_config(mode)
        self.register_extensions()
        self.configure_login_manager()
        self.configure_context_processor()
        self.configure_errors()
        self.configure_authentication()
        self.configure_routes()
        self.configure_rest_api()
        self.configure_cli()

    @staticmethod
    def catch_exceptions(func):
        """Decorate a REST handler so that errors become flask-restful aborts.

        A ``LookupError`` raised by ``func`` is reported as a 404 and any other
        ``Exception`` as a 500. In both cases the exception text is the message.
        """

        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as exc:
                status = 404 if isinstance(exc, LookupError) else 500
                rest_abort(status, message=str(exc))

        return wrapper

    @staticmethod
    def monitor_requests(function):
        """Guard a view with a login check and per-group GET restrictions.

        Anonymous callers are logged with ``app.log("warning", ...)`` and
        redirected to the login page. Authenticated users receive the 403
        error page for a GET whose path is listed in their group's "GET" list
        in ``app.rbac["groups"]``. Every other request reaches the view.
        """

        @wraps(function)
        def decorated_function(*args, **kwargs):
            if current_user.is_authenticated:
                blocked_gets = app.rbac["groups"][current_user.group]["GET"]
                if request.method == "GET" and request.path in blocked_gets:
                    return render_template("error.html", error=403), 403
                return function(*args, **kwargs)

            env = request.environ
            # Behind a proxy the real client is in X-Forwarded-For.
            client_address = env.get("HTTP_X_FORWARDED_FOR",
                                     env["REMOTE_ADDR"])
            message = (f"Unauthorized {request.method} request from "
                       f"'{client_address}' calling the endpoint '{request.url}'")
            app.log("warning", message)
            return redirect(url_for("blueprint.route", page="login"))

        return decorated_function

    def update_config(self, mode):
        """Apply the Flask settings that depend on the configuration mode.

        ``mode`` defaults to ``app.settings["app"]["config_mode"]`` and is
        compared in lower case. Every mode except "production" enables DEBUG,
        and CSRF protection is switched off only in "test" mode. Uploads are
        capped at 20 MiB.
        """
        mode = (mode or app.settings["app"]["config_mode"]).lower()
        settings = dict(
            DEBUG=mode != "production",
            # NOTE(review): the fallback key is public; production deployments
            # must set SECRET_KEY in the environment.
            SECRET_KEY=environ.get("SECRET_KEY", "get-a-real-key"),
            WTF_CSRF_TIME_LIMIT=None,
            ERROR_404_HELP=False,
            MAX_CONTENT_LENGTH=20 * 1024 * 1024,
            WTF_CSRF_ENABLED=mode != "test",
        )
        self.config.update(settings)

    def register_extensions(self):
        """Create the HTTP basic-auth handler and enable CSRF protection."""
        self.auth = HTTPBasicAuth()
        self.csrf = csrf = CSRFProtect()
        csrf.init_app(self)

    def configure_login_manager(self):
        """Attach a Flask-Login manager that restores users from the session.

        Users are reloaded from the id stored in the session by ``login_user``.
        The former ``request_loader`` returned the user named by the ``name``
        form field of any request, without a password check. That let anyone
        act as any user, so it is no longer registered.
        """
        login_manager = LoginManager()
        # Flask-Login's "strong" protection drops the session when the
        # client identifier it computes changes.
        login_manager.session_protection = "strong"
        login_manager.init_app(self)

        @login_manager.user_loader
        def user_loader(user_id):
            return db.fetch("user", allow_none=True, id=int(user_id))

    def configure_context_processor(self):
        """Register a context processor that exposes UI metadata to templates.

        Every template receives:
        - property and form definitions and the RBAC menu;
        - the current user's group rights ("Read Only" when the user has no
          ``group`` attribute);
        - relationship names, and the pretty names of models that define
          ``pretty_name``;
        - settings and table properties;
        - the serialized user (None when anonymous) and the version.
        """

        def pretty_service_names():
            # Only models exposing ``pretty_name`` are listed as services.
            return {
                service: model.pretty_name
                for service, model in sorted(models.items())
                if hasattr(model, "pretty_name")
            }

        @self.context_processor
        def inject_properties():
            user = current_user
            return dict(
                property_types=property_types,
                form_properties=form_properties,
                menu=rbac["menu"],
                names=app.property_names,
                rbac=rbac["groups"][getattr(user, "group", "Read Only")],
                relations=list(
                    set(chain.from_iterable(relationships.values()))),
                relationships=relationships,
                service_types=pretty_service_names(),
                settings=app.settings,
                table_properties=app.properties["tables"],
                user=user.serialized if user.is_authenticated else None,
                version=app.version,
            )

    def configure_errors(self):
        """Render the shared error page for 403 and 404 responses."""

        def error_page(code):
            # Build a handler bound to ``code`` (avoids late binding in the loop).
            def handler(error):
                return render_template("error.html", error=code), code

            return handler

        for code in (403, 404):
            self.register_error_handler(code, error_page(code))

    def configure_authentication(self):
        """Register the HTTP basic-auth callbacks used by the REST API.

        ``verify_password`` accepts a request only when the named user exists
        and the supplied password matches the stored one. The match uses
        ``argon2.verify`` when ``settings["security"]["hash_user_passwords"]``
        is enabled; otherwise it is a constant-time comparison of the two
        plain-text strings. Failed checks go to ``unauthorized`` (JSON 401).
        """
        import hmac

        @self.auth.verify_password
        def verify_password(username, password):
            # allow_none=True: an unknown user is simply a failed login rather
            # than an error escaping the auth layer (fetch presumably raises
            # otherwise).
            user = db.fetch("user", allow_none=True, name=username)
            if user is None:
                return False
            stored = app.get_password(user.password)
            if app.settings["security"]["hash_user_passwords"]:
                return argon2.verify(password, stored)
            # The former ``str.__eq__(password, stored)`` returned the truthy
            # NotImplemented whenever ``stored`` was not a str (e.g. None),
            # which accepted any password; reject such values explicitly.
            if not isinstance(password, str) or not isinstance(stored, str):
                return False
            return hmac.compare_digest(password.encode("utf-8"),
                                       stored.encode("utf-8"))

        @self.auth.get_password
        def get_password(username):
            return getattr(db.fetch("user", name=username), "password", False)

        @self.auth.error_handler
        def unauthorized():
            return make_response(jsonify({"message": "Wrong credentials."}),
                                 401)

    def configure_routes(self):
        """Register the HTML pages and the generic POST dispatcher on a blueprint.

        Every page except "/" and "/login" is wrapped in ``monitor_requests``
        (login required, per-group GET restrictions). A POST to
        "/<endpoint>/<arg>/..." is checked against ``app.rbac`` and dispatched
        to ``app.<endpoint>(*args, **payload)``. The payload is the JSON body,
        the validated form, or the raw form. The DB session is then committed;
        on failure it is rolled back and a JSON alert is returned (the error
        is re-raised in debug mode).
        """
        blueprint = Blueprint("blueprint",
                              __name__,
                              template_folder="../templates")

        @blueprint.route("/")
        def site_root():
            return redirect(url_for("blueprint.route", page="login"))

        @blueprint.route("/login", methods=["GET", "POST"])
        def login():
            if request.method == "POST":
                # Guard only the credential check. Previously abort(403) was
                # raised inside the try and then caught and logged as an error.
                try:
                    user = app.authenticate_user(**request.form.to_dict())
                except Exception as exc:
                    info(f"Authentication failed ({str(exc)})")
                    abort(403)
                if not user:
                    info("Authentication failed (invalid credentials)")
                    abort(403)
                login_user(user)
                return redirect(url_for("blueprint.route", page="dashboard"))
            if not current_user.is_authenticated:
                login_form = LoginForm(request.form)
                authentication_methods = []
                if app.settings["ldap"]["active"]:
                    authentication_methods.append(("LDAP Domain", ) * 2)
                if app.settings["tacacs"]["active"]:
                    authentication_methods.append(("TACACS", ) * 2)
                authentication_methods.append(("Local User", ) * 2)
                login_form.authentication_method.choices = authentication_methods
                return render_template("login.html", login_form=login_form)
            return redirect(url_for("blueprint.route", page="dashboard"))

        @blueprint.route("/dashboard")
        @self.monitor_requests
        def dashboard():
            return render_template("dashboard.html",
                                   endpoint="dashboard",
                                   properties=properties["dashboard"])

        @blueprint.route("/logout")
        @self.monitor_requests
        def logout():
            logout_user()
            return redirect(url_for("blueprint.route", page="login"))

        @blueprint.route("/table/<table_type>")
        @self.monitor_requests
        def table(table_type):
            return render_template("table.html",
                                   endpoint=f"table/{table_type}",
                                   type=table_type)

        @blueprint.route("/view/<view_type>")
        @self.monitor_requests
        def view(view_type):
            return render_template("visualization.html",
                                   endpoint="view",
                                   view_type=view_type)

        @blueprint.route("/workflow_builder")
        @self.monitor_requests
        def workflow_builder():
            return render_template("workflow.html",
                                   endpoint="workflow_builder")

        @blueprint.route("/form/<form_type>")
        @self.monitor_requests
        def form(form_type):
            return render_template(
                f"forms/{form_templates.get(form_type, 'base')}.html",
                endpoint=f"forms/{form_type}",
                action=form_actions.get(form_type),
                form=form_classes[form_type](request.form),
                form_type=form_type,
            )

        @blueprint.route("/help/<path:path>")
        @self.monitor_requests
        def help(path):
            return render_template(f"help/{path}.html")

        @blueprint.route("/view_service_results/<int:id>")
        @self.monitor_requests
        def view_service_results(id):
            result = db.fetch("run", id=id).result().result
            return f"<pre>{app.str_dict(result)}</pre>"

        @blueprint.route("/download_file/<path:path>")
        @self.monitor_requests
        def download_file(path):
            # NOTE(review): serves any absolute path readable by the process to
            # any logged-in user; consider restricting to known directories.
            return send_file(f"/{path}", as_attachment=True)

        @blueprint.route("/<path:_>")
        @self.monitor_requests
        def get_requests_sink(_):
            abort(404)

        @blueprint.route("/", methods=["POST"])
        @blueprint.route("/<path:page>", methods=["POST"])
        @self.monitor_requests
        def route(page=""):
            # The "/" rule passes no ``page``; without a default this raised
            # TypeError. An empty endpoint is rejected by the RBAC check below.
            endpoint, *args = page.split("/")
            if f"/{endpoint}" not in app.rbac["endpoints"]["POST"]:
                return jsonify({"alert": "Invalid POST request."})
            if f"/{endpoint}" in app.rbac["groups"][
                    current_user.group]["POST"]:
                return jsonify({"alert": "Error 403 Forbidden."})
            form_type = request.form.get("form_type")
            if endpoint in app.json_endpoints:
                result = getattr(app, endpoint)(*args, **request.json)
            elif form_type:
                form = form_classes[form_type](request.form)
                if not form.validate_on_submit():
                    return jsonify({
                        "invalid_form": True,
                        "errors": form.errors
                    })
                result = getattr(app, endpoint)(*args, **form_postprocessing(
                    form, request.form))
            else:
                result = getattr(app, endpoint)(*args, **request.form)
            try:
                db.session.commit()
                return jsonify(result)
            except Exception as exc:
                # A stray ``raise exc`` used to make everything below
                # unreachable, leaving the session without a rollback.
                db.session.rollback()
                if app.settings["app"]["config_mode"] == "debug":
                    raise
                match = search(r"UNIQUE constraint failed: (\w+)\.(\w+)",
                               str(exc))
                if match:
                    result = (f"There already is a {match.group(1)} "
                              f"with the same {match.group(2)}.")
                else:
                    result = str(exc)
                return jsonify({"alert": result})

        self.register_blueprint(blueprint)

    def configure_cli(self):
        """Register the ``run_service`` and ``delete-changelog`` CLI commands."""

        @self.cli.command(name="run_service")
        @argument("name")
        @option("--devices")
        @option("--payload")
        def start(name, devices, payload):
            # ``devices`` is a comma-separated list of device names and
            # ``payload`` an optional JSON object of extra run arguments.
            device_names = devices.split(",") if devices else []
            device_ids = [
                db.fetch("device", name=device_name).id
                for device_name in device_names
            ]
            run_kwargs = loads(payload) if payload else {}
            run_kwargs["devices"] = device_ids
            service = db.fetch("service", name=name)
            results = app.run(service.id, **run_kwargs)
            db.session.commit()
            echo(app.str_dict(results))

        @self.cli.command(name="delete-changelog")
        @option(
            "--keep-last-days",
            default=15,
            help="Number of days to keep in the changelog",
        )
        def remove_changelog(keep_last_days):
            cutoff = datetime.now() - timedelta(days=keep_last_days)
            app.result_log_deletion(
                date_time=cutoff.strftime("%d/%m/%Y %H:%M:%S"),
                deletion_types=["changelog"],
            )
            app.log("info", f"deleted all changelogs up until {cutoff}")

    def configure_rest_api(self):

        api = Api(self, decorators=[self.csrf.exempt])

        class CreatePool(Resource):
            decorators = [self.auth.login_required, self.catch_exceptions]

            def post(self):
                data = request.get_json(force=True)
                db.factory(
                    "pool",
                    **{
                        "name":
                        data["name"],
                        "devices": [
                            db.fetch("device", name=name).id
                            for name in data.get("devices", "")
                        ],
                        "links": [
                            db.fetch("link", name=name).id
                            for name in data.get("links", "")
                        ],
                        "manually_defined":
                        True,
                    },
                )
                db.session.commit()
                return data

        class Heartbeat(Resource):
            def get(self):
                return {
                    "name": getnode(),
                    "cluster_id": app.settings["cluster"]["id"],
                }

        class Query(Resource):
            decorators = [self.auth.login_required, self.catch_exceptions]

            def get(self, cls):
                results = db.fetch(cls,
                                   all_matches=True,
                                   **request.args.to_dict())
                return [
                    result.get_properties(exclude=["positions"])
                    for result in results
                ]

        class GetInstance(Resource):
            decorators = [self.auth.login_required, self.catch_exceptions]

            def get(self, cls, name):
                return db.fetch(cls,
                                name=name).to_dict(relation_names_only=True,
                                                   exclude=["positions"])

            def delete(self, cls, name):
                result = db.delete(cls, name=name)
                db.session.commit()
                return result

        class GetConfiguration(Resource):
            decorators = [self.auth.login_required, self.catch_exceptions]

            def get(self, name):
                return db.fetch("device", name=name).configuration

        class GetResult(Resource):
            decorators = [self.auth.login_required, self.catch_exceptions]

            def get(self, name, runtime):
                service = db.fetch("service", name=name)
                run = db.fetch("run", service_id=service.id, runtime=runtime)
                return run.result().result

        class UpdateInstance(Resource):
            decorators = [self.auth.login_required, self.catch_exceptions]

            def post(self, cls):
                data = request.get_json(force=True)
                object_data = app.objectify(cls, data)
                result = db.factory(cls, **object_data).serialized
                db.session.commit()
                return result

        class Migrate(Resource):
            decorators = [self.auth.login_required, self.catch_exceptions]

            def post(self, direction):
                kwargs = request.get_json(force=True)
                return getattr(app, f"migration_{direction}")(**kwargs)

        class RunService(Resource):
            decorators = [self.auth.login_required, self.catch_exceptions]

            def post(self):
                data = {
                    "trigger": "REST",
                    "creator": request.authorization["username"],
                    **request.get_json(force=True),
                }
                errors, devices, pools = [], [], []
                service = db.fetch("service", name=data["name"])
                handle_asynchronously = data.get("async", False)
                for device_name in data.get("devices", ""):
                    device = db.fetch("device", name=device_name)
                    if device:
                        devices.append(device.id)
                    else:
                        errors.append(
                            f"No device with the name '{device_name}'")
                for device_ip in data.get("ip_addresses", ""):
                    device = db.fetch("device", ip_address=device_ip)
                    if device:
                        devices.append(device.id)
                    else:
                        errors.append(
                            f"No device with the IP address '{device_ip}'")
                for pool_name in data.get("pools", ""):
                    pool = db.fetch("pool", name=pool_name)
                    if pool:
                        pools.append(pool.id)
                    else:
                        errors.append(f"No pool with the name '{pool_name}'")
                if errors:
                    return {"errors": errors}
                if devices or pools:
                    data.update({"devices": devices, "pools": pools})
                data["runtime"] = runtime = app.get_time()
                if handle_asynchronously:
                    app.scheduler.add_job(
                        id=runtime,
                        func=app.run,
                        run_date=datetime.now(),
                        args=[service.id],
                        kwargs=data,
                        trigger="date",
                    )
                    return {"errors": errors, "runtime": runtime}
                else:
                    return {**app.run(service.id, **data), "errors": errors}

        class Topology(Resource):
            decorators = [self.auth.login_required, self.catch_exceptions]

            def post(self, direction):
                if direction == "import":
                    return app.import_topology(
                        **{
                            "replace": request.form["replace"] == "True",
                            "file": request.files["file"],
                        })
                else:
                    app.export_topology(**request.get_json(force=True))
                    return "Topology Export successfully executed."

        class Search(Resource):
            decorators = [self.auth.login_required, self.catch_exceptions]

            def post(self):
                rest_body = request.get_json(force=True)
                kwargs = {
                    "draw":
                    1,
                    "columns": [{
                        "data": column
                    } for column in rest_body["columns"]],
                    "order": [{
                        "column": 0,
                        "dir": "asc"
                    }],
                    "start":
                    0,
                    "length":
                    rest_body["maximum_return_records"],
                    "form":
                    rest_body["search_criteria"],
                    "rest_api_request":
                    True,
                }
                return app.filtering(rest_body["type"], **kwargs)["data"]

        class Sink(Resource):
            def get(self, **_):
                rest_abort(
                    404,
                    message=
                    f"The requested {request.method} endpoint does not exist.",
                )

            post = put = patch = delete = get

        for endpoint in app.rest_endpoints:

            def post(_, ep=endpoint):
                getattr(app, ep)()
                db.session.commit()
                return f"Endpoint {ep} successfully executed."

            api.add_resource(
                type(
                    endpoint,
                    (Resource, ),
                    {
                        "decorators":
                        [self.auth.login_required, self.catch_exceptions],
                        "post":
                        post,
                    },
                ),
                f"/rest/{endpoint}",
            )
        api.add_resource(CreatePool, "/rest/create_pool")
        api.add_resource(Heartbeat, "/rest/is_alive")
        api.add_resource(RunService, "/rest/run_service")
        api.add_resource(Query, "/rest/query/<string:cls>")
        api.add_resource(UpdateInstance, "/rest/instance/<string:cls>")
        api.add_resource(GetInstance,
                         "/rest/instance/<string:cls>/<string:name>")
        api.add_resource(GetConfiguration, "/rest/configuration/<string:name>")
        api.add_resource(Search, "/rest/search")
        api.add_resource(GetResult,
                         "/rest/result/<string:name>/<string:runtime>")
        api.add_resource(Migrate, "/rest/migrate/<string:direction>")
        api.add_resource(Topology, "/rest/topology/<string:direction>")
        api.add_resource(Sink, "/rest/<path:path>")
Beispiel #25
0
# from concurrent.futures import ThreadPoolExecutor
from common.permission import PermissionManager
from redis import ConnectionPool
from typing import Set, NoReturn
import os
import celery
from flask_cors import CORS
web_app = flask.Flask("HelloJudge2")
web_app.config["SQLALCHEMY_DATABASE_URI"] = config.DATABASE_URI
# Permanent sessions stay valid for seven days.
web_app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(days=7)
web_app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

web_app.secret_key = config.SESSION_KEY
# The CSRFProtect object is always created (presumably so other modules can
# import `csrf`), but it is only bound to the app when the config enables it.
csrf = CSRFProtect()
if config.ENABLE_CSRF_TOKEN:
    csrf.init_app(web_app)
if config.DEBUG:
    import logging
    # Use setLevel() rather than assigning `.level`: setLevel() also clears
    # the logging module's cached isEnabledFor() results.
    logging.getLogger('flask_cors').setLevel(logging.DEBUG)
db = SQLAlchemy(web_app)
# NOTE(review): every origin is allowed *and* credentials are supported. With
# this combination flask_cors echoes the caller's Origin back, so any site can
# make credentialed cross-origin requests. Confirm this is intended.
CORS(web_app,
     supports_credentials=True,
     resources={".*": {
         "origins": "*",
         "supports_credentials": True
     }})
basedir = os.path.dirname(__file__)
logger = web_app.logger
socket = SocketIO(web_app)
# Celery client used to enqueue tasks on the Redis broker.
queue = celery.Celery(web_app.name, broker=config.REDIS_URI)
remote_judge_queue = celery.Celery(web_app.name,
Beispiel #26
0
def setup_csrf():
    """Enable Flask-WTF CSRF protection on the module-level ``app``.

    The signing key is taken from the ``SECRET_KEY`` environment variable
    when that variable is set (non-empty). When it is not set, any key
    already configured on the app is kept. It is not overwritten with
    ``None``, which would leave CSRF unable to sign tokens.

    Returns:
        True once the extension has been registered on the app.
    """
    csrf = CSRFProtect()
    secret_key = os.getenv('SECRET_KEY')
    if secret_key:
        app.config['SECRET_KEY'] = secret_key
    csrf.init_app(app)
    return True
Beispiel #27
0
class Server(Flask):
    """Flask application for eNMS: web UI routes, authentication and REST API.

    Everything is wired up in ``__init__``, in this order:
    1. load the configuration;
    2. register extensions and plugins;
    3. configure the login manager, template context, error pages and HTTP
       basic auth;
    4. configure the UI blueprint and the REST API.

    The module-level ``app`` object (defined elsewhere) supplies settings and
    the business methods that the routes dispatch to.
    """

    def __init__(self, mode=None):
        """Create and fully configure the server.

        Args:
            mode: configuration mode (e.g. "production" or "test"). When
                falsy, ``app.settings["app"]["config_mode"]`` is used.
        """
        static_folder = str(app.path / "eNMS" / "static")
        super().__init__(__name__, static_folder=static_folder)
        self.update_config(mode)
        # Must run before configure_authentication / configure_rest_api,
        # which use self.auth and self.csrf.
        self.register_extensions()
        self.register_plugins()
        self.configure_login_manager()
        self.configure_context_processor()
        self.configure_errors()
        self.configure_authentication()
        self.configure_routes()
        self.configure_rest_api()

    @staticmethod
    def monitor_requests(function):
        """Guard a UI view behind login and per-path GET permissions.

        Unauthenticated calls are logged as a warning and redirected to the
        login page. A non-admin GET request is allowed only when the request
        path is listed in ``current_user.get_requests``; otherwise the 403
        error page is rendered.
        """
        @wraps(function)
        def decorated_function(*args, **kwargs):
            if not current_user.is_authenticated:
                # X-Forwarded-For is client-controlled; it is only used for
                # the log line below.
                client_address = request.environ.get(
                    "HTTP_X_FORWARDED_FOR", request.environ["REMOTE_ADDR"])
                app.log(
                    "warning",
                    (f"Unauthorized {request.method} request from "
                     f"'{client_address}' calling the endpoint '{request.url}'"
                     ),
                )
                return redirect(url_for("blueprint.route", page="login"))
            else:
                if (not current_user.is_admin and request.method == "GET"
                        and request.path not in current_user.get_requests):
                    return render_template("error.html", error=403), 403
                return function(*args, **kwargs)

        return decorated_function

    def update_config(self, mode):
        """Populate the Flask configuration for the given mode.

        The mode (lower-cased) switches DEBUG on outside "production" and
        disables CSRF only for "test".
        """
        mode = (mode or app.settings["app"]["config_mode"]).lower()
        self.config.update({
            "DEBUG":
            mode != "production",
            # NOTE(review): if the SECRET_KEY env var is unset, this falls back
            # to a hard-coded, publicly known key. Sessions and CSRF tokens
            # could then be forged. Consider failing fast in production.
            "SECRET_KEY":
            getenv("SECRET_KEY", "secret_key"),
            # None: CSRF tokens do not expire for the life of the session.
            "WTF_CSRF_TIME_LIMIT":
            None,
            # Flask-RESTful: no "did you mean" suggestions in 404 messages.
            "ERROR_404_HELP":
            False,
            # Cap request bodies (e.g. topology uploads) at 20 MiB.
            "MAX_CONTENT_LENGTH":
            20 * 1024 * 1024,
            "WTF_CSRF_ENABLED":
            mode != "test",
            "PERMANENT_SESSION_LIFETIME":
            timedelta(minutes=app.settings["app"]["session_timeout_minutes"]),
        })

    def register_plugins(self):
        """Load every active plugin found under the configured plugin path.

        A plugin is a sub-directory that contains a ``settings.json`` with a
        true ``active`` flag. For each plugin:
        - ``eNMS.plugins.<dir>`` is imported and its ``Plugin`` class is
          instantiated;
        - its "database", "properties" and "rbac" settings are passed to
          ``update_file`` together with the matching ``app`` attribute.

        A plugin that fails to load is logged and skipped. Afterwards the
        variable forms are initialised and ``create_all`` creates any missing
        tables.
        """
        for plugin_path in Path(app.settings["app"]["plugin_path"]).iterdir():
            if not Path(plugin_path / "settings.json").exists():
                continue
            try:
                with open(plugin_path / "settings.json", "r") as file:
                    settings = load(file)
                if not settings["active"]:
                    continue
                module = import_module(f"eNMS.plugins.{plugin_path.stem}")
                module.Plugin(self, app, db, **settings)
                for setup_file in ("database", "properties", "rbac"):
                    update_file(getattr(app, setup_file),
                                settings.get(setup_file, {}))
            except Exception as exc:
                app.log("error",
                        f"Could not load plugin '{plugin_path.stem}' ({exc})")
                continue
            # Fall back to the directory name so that a settings file without
            # a "name" key cannot abort start-up after the plugin has loaded.
            plugin_name = settings.get("name", plugin_path.stem)
            app.log("info", f"Loading plugin: {plugin_name}")
        init_variable_forms(app)
        db.base.metadata.create_all(bind=db.engine)

    def register_extensions(self):
        """Create the HTTP basic-auth handler and enable CSRF protection."""
        self.auth = HTTPBasicAuth()
        self.csrf = CSRFProtect()
        self.csrf.init_app(self)

    def configure_login_manager(self):
        """Set up Flask-Login session handling and its user loaders."""
        login_manager = LoginManager()
        login_manager.session_protection = "strong"
        login_manager.init_app(self)

        @login_manager.user_loader
        def user_loader(name):
            # The session stores the user's name as its identifier.
            return db.fetch("user", allow_none=True, name=name)

        @login_manager.request_loader
        def request_loader(request):
            # NOTE(review): SECURITY. This resolves a user from the "name"
            # form field alone, with no password check. Flask-Login consults
            # the request loader when the session holds no user, so a form
            # POST containing name=<user> appears to run as that user.
            # Confirm whether anything relies on this; otherwise remove it.
            return db.fetch("user",
                            allow_none=True,
                            name=request.form.get("name"))

    def configure_context_processor(self):
        """Expose shared properties, settings and the current user to templates."""
        @self.context_processor
        def inject_properties():
            return {
                "configuration_properties":
                app.configuration_properties,
                "form_properties":
                form_properties,
                "menu":
                rbac["menu"],
                "names":
                app.property_names,
                "property_types":
                property_types,
                "relations":
                list(set(chain.from_iterable(relationships.values()))),
                "relationships":
                relationships,
                # Only model classes that define a pretty_name are listed.
                "service_types": {
                    service: service_class.pretty_name
                    for service, service_class in sorted(models.items())
                    if hasattr(service_class, "pretty_name")
                },
                "settings":
                app.settings,
                "themes":
                themes,
                "table_properties":
                app.properties["tables"],
                "user":
                current_user.serialized
                if current_user.is_authenticated else None,
                "version":
                app.version,
                "visualization":
                visualization,
            }

    def configure_errors(self):
        """Render the shared error page for 403 and 404 responses."""
        @self.errorhandler(403)
        def authorization_required(error):
            return render_template("error.html", error=403), 403

        @self.errorhandler(404)
        def not_found_error(error):
            return render_template("error.html", error=404), 404

    def configure_authentication(self):
        """Configure HTTP basic auth for the REST API.

        A request is authorised when the credentials are valid and either the
        user is an admin or the endpoint is listed in the user's
        ``<method>_requests`` attribute. The endpoint is the first two path
        segments, e.g. "/rest/instance". ``g.status`` records why a request
        was rejected (401 for bad credentials, 403 for insufficient rights).
        """
        @self.auth.verify_password
        def verify_password(username, password):
            user = app.authenticate_user(name=username, password=password)
            if user:
                request_type = f"{request.method.lower()}_requests"
                endpoint = "/".join(request.path.split("/")[:3])
                if user.is_admin or endpoint in getattr(
                        user, request_type, []):
                    login_user(user)
                    return True
                g.status = 403
            else:
                g.status = 401

        @self.auth.get_password
        def get_password(username):
            user = db.fetch("user", allow_none=True, name=username)
            return getattr(user, "password", False)

        @self.auth.error_handler
        def unauthorized():
            # Default to 401 if verify_password never set a status.
            status = g.get("status", 401)
            message = f"{'Wrong' if status == 401 else 'Insufficient'} credentials"
            return make_response(jsonify({"message": message}), status)

    def configure_routes(self):
        """Register the web UI blueprint (pages, downloads and POST dispatch)."""
        blueprint = Blueprint("blueprint",
                              __name__,
                              template_folder="../templates")

        @blueprint.route("/")
        def site_root():
            return redirect(url_for("blueprint.route", page="login"))

        @blueprint.route("/login", methods=["GET", "POST"])
        def login():
            """Show the login form, or authenticate a submitted one.

            POST logs the outcome to the security logger. On success it
            redirects to the dashboard; on failure it aborts with 403.
            """
            if request.method == "POST":
                kwargs, success = request.form.to_dict(), False
                username = kwargs["name"]
                try:
                    user = app.authenticate_user(**kwargs)
                    if user:
                        login_user(user, remember=False)
                        session.permanent = True
                        success, log = True, f"User '{username}' logged in"
                    else:
                        log = f"Authentication failed for user '{username}'"
                except Exception as exc:
                    log = f"Authentication error for user '{username}' ({exc})"
                # Kept out of a `finally` block so that a return/abort here
                # cannot swallow an in-flight exception.
                app.log("info" if success else "warning",
                        log,
                        logger="security")
                if not success:
                    abort(403)
                return redirect(url_for("blueprint.route", page="dashboard"))
            if not current_user.is_authenticated:
                login_form = LoginForm(request.form)
                methods = app.settings["authentication"]["methods"].items()
                login_form.authentication_method.choices = [
                    (method, properties["display_name"])
                    for method, properties in methods if properties["enabled"]
                ]
                return render_template("login.html", login_form=login_form)
            return redirect(url_for("blueprint.route", page="dashboard"))

        @blueprint.route("/dashboard")
        @self.monitor_requests
        def dashboard():
            return render_template(
                "dashboard.html",
                **{
                    "endpoint": "dashboard",
                    "properties": properties["dashboard"]
                },
            )

        @blueprint.route("/logout")
        @self.monitor_requests
        def logout():
            logout_log = f"User '{current_user.name}' logging out"
            app.log("info", logout_log, logger="security")
            logout_user()
            return redirect(url_for("blueprint.route", page="login"))

        @blueprint.route("/table/<table_type>")
        @self.monitor_requests
        def table(table_type):
            return render_template(
                "table.html", **{
                    "endpoint": f"table/{table_type}",
                    "type": table_type
                })

        @blueprint.route("/visualization/<view_type>")
        @self.monitor_requests
        def view(view_type):
            return render_template(
                "visualization.html",
                endpoint=view_type,
                default_pools=app.get_visualization_parameters(),
            )

        @blueprint.route("/workflow_builder")
        @self.monitor_requests
        def workflow_builder():
            return render_template("workflow.html",
                                   endpoint="workflow_builder")

        @blueprint.route("/form/<form_type>")
        @self.monitor_requests
        def form(form_type):
            form = form_classes[form_type](request.form)
            return render_template(
                f"forms/{getattr(form, 'template', 'base')}.html",
                **{
                    "endpoint": f"forms/{form_type}",
                    "action": getattr(form, "action", None),
                    "button_label": getattr(form, "button_label", "Confirm"),
                    "button_class": getattr(form, "button_class", "success"),
                    "form": form,
                    "form_type": form_type,
                },
            )

        @blueprint.route("/help/<path:path>")
        @self.monitor_requests
        def help(path):
            return render_template(f"help/{path}.html")

        @blueprint.route("/view_service_results/<int:id>")
        @self.monitor_requests
        def view_service_results(id):
            from html import escape

            result = db.fetch("run", id=id).result().result
            # Results may contain arbitrary device output: escape it so it is
            # displayed as text instead of being interpreted as HTML.
            return f"<pre>{escape(str(app.str_dict(result)))}</pre>"

        @blueprint.route("/download_file/<path:path>")
        @self.monitor_requests
        def download_file(path):
            # NOTE(review): serves any absolute path on the server to users
            # who pass monitor_requests (admins always do). Confirm that this
            # is intended, or restrict it to a whitelisted directory.
            return send_file(f"/{path}", as_attachment=True)

        @blueprint.route("/export_service/<int:id>")
        @self.monitor_requests
        def export_service(id):
            return send_file(f"/{app.export_service(id)}.tgz",
                             as_attachment=True)

        @blueprint.route("/<path:_>")
        @self.monitor_requests
        def get_requests_sink(_):
            abort(404)

        @blueprint.route("/", methods=["POST"])
        @blueprint.route("/<path:page>", methods=["POST"])
        @self.monitor_requests
        def route(page=""):
            """Dispatch a UI POST to the ``app`` method named by the path.

            The first path segment names the method, and the remaining
            segments are passed as positional arguments. Keyword arguments
            come from the JSON body, from a validated form (when form_type is
            posted), or else from the raw form. ``page`` defaults to "" so
            that the bare "/" rule, which supplies no value, does not raise a
            TypeError.
            """
            endpoint, *args = page.split("/")
            admin_user = current_user.is_admin
            if f"/{endpoint}" not in app.rbac["post_requests"]:
                return jsonify({"alert": "Invalid POST request."})
            if not admin_user and f"/{endpoint}" not in current_user.post_requests:
                return jsonify({"alert": "Error 403 - Operation not allowed."})
            form_type = request.form.get("form_type")
            # NOTE(review): on Werkzeug >= 2.1, request.json raises for
            # non-JSON bodies; get_json(silent=True) would be needed there.
            if request.json:
                kwargs = request.json
            elif form_type:
                form = form_classes[form_type](request.form)
                if not form.validate_on_submit():
                    return jsonify({
                        "invalid_form": True,
                        **{
                            "errors": form.errors
                        }
                    })
                kwargs = form.form_postprocessing(request.form)
            else:
                kwargs = request.form
            try:
                with db.session_scope():
                    result = getattr(app, endpoint)(*args, **kwargs)
            except db.rbac_error:
                result = {"alert": "Error 403 - Operation not allowed."}
            except Exception:
                app.log("error", format_exc(), change_log=False)
                result = {"alert": "Error 500 - Internal Server Error"}
            return jsonify(result)

        self.register_blueprint(blueprint)

    @staticmethod
    def monitor_rest_request(func):
        """Run a REST handler, commit the session, and retry on commit failure.

        If the handler raises ``db.rbac_error``, the response is 404. Any
        other handler error gives 500. When the commit fails, the session is
        rolled back and the handler is run again. Up to
        ``db.retry_commit_number`` attempts are made, sleeping
        ``db.retry_commit_time * attempt`` seconds between them. If every
        attempt fails, a 500 response carrying the last stack trace is
        returned.
        """
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Bound up front so the for/else below cannot hit a NameError
            # when retry_commit_number is 0 and the loop never runs.
            stacktrace = "No commit attempt was made."
            for index in range(db.retry_commit_number):
                try:
                    result = func(*args, **kwargs)
                except db.rbac_error as exc:
                    return rest_abort(404, message=str(exc))
                except Exception as exc:
                    return rest_abort(500, message=str(exc))
                try:
                    db.session.commit()
                    return result
                except Exception as exc:
                    db.session.rollback()
                    app.log("error", f"Rest Call n°{index} failed ({exc}).")
                    stacktrace = format_exc()
                    sleep(db.retry_commit_time * (index + 1))
            else:
                return rest_abort(500, message=stacktrace)

        return wrapper

    def configure_rest_api(self):
        """Register the Flask-RESTful API under /rest (CSRF-exempt, basic auth)."""

        api = Api(self, decorators=[self.csrf.exempt])

        class Heartbeat(Resource):
            # Unauthenticated liveness probe.
            def get(self):
                return {
                    "name": getnode(),
                    "cluster_id": app.settings["cluster"]["id"],
                }

        class Query(Resource):
            decorators = [self.auth.login_required, self.monitor_rest_request]

            def get(self, model):
                # The query-string arguments are used as filters for fetching
                # every matching instance.
                properties = request.args.to_dict()
                results = db.fetch(model, all_matches=True, **properties)
                return [
                    result.get_properties(exclude=["positions"])
                    for result in results
                ]

        class GetInstance(Resource):
            decorators = [self.auth.login_required, self.monitor_rest_request]

            def get(self, model, name):
                return db.fetch(model,
                                name=name).to_dict(relation_names_only=True,
                                                   exclude=["positions"])

            def delete(self, model, name):
                result = db.delete(model, name=name)
                return result

        class GetConfiguration(Resource):
            decorators = [self.auth.login_required, self.monitor_rest_request]

            def get(self, name):
                return db.fetch("device", name=name).configuration

        class GetResult(Resource):
            decorators = [self.auth.login_required, self.monitor_rest_request]

            def get(self, name, runtime):
                run = db.fetch("run",
                               service_name=name,
                               runtime=runtime,
                               allow_none=True)
                if not run:
                    error_message = (
                        "There are no results or on-going services "
                        "for the requested service and runtime.")
                    return {"error": error_message}
                else:
                    result = run.result()
                    return {
                        "status": run.status,
                        "result":
                        result.result if result else "No results yet.",
                    }

        class UpdateInstance(Resource):
            decorators = [self.auth.login_required, self.monitor_rest_request]

            def post(self, model):
                # Accepts a single object or a list of objects. Each object is
                # created or updated independently, and the successes and
                # failures are reported per object.
                data, result = request.get_json(force=True), defaultdict(list)
                if not isinstance(data, list):
                    data = [data]
                for instance in data:
                    if "name" not in instance:
                        result["failure"].append((instance, "Name is missing"))
                        continue
                    try:
                        object_data = app.objectify(model, instance)
                        object_data["update_pools"] = model in properties[
                            "filtering"]
                        instance = db.factory(model, **object_data)
                        result["success"].append(instance.name)
                    except Exception:
                        result["failure"].append((instance, format_exc()))
                return result

        class Migrate(Resource):
            decorators = [self.auth.login_required, self.monitor_rest_request]

            def post(self, direction):
                kwargs = request.get_json(force=True)
                return getattr(app, f"migration_{direction}")(**kwargs)

        class RunService(Resource):
            decorators = [self.auth.login_required, self.monitor_rest_request]

            def post(self):
                # Resolve device names, device IPs and pool names to ids. The
                # service only runs if every target was found.
                data = {
                    "trigger": "REST",
                    "creator": request.authorization["username"],
                    **request.get_json(force=True),
                }
                errors, devices, pools = [], [], []
                service = db.fetch("service", name=data["name"], rbac="run")
                handle_asynchronously = data.get("async", False)
                for device_name in data.get("devices", ""):
                    device = db.fetch("device", name=device_name)
                    if device:
                        devices.append(device.id)
                    else:
                        errors.append(
                            f"No device with the name '{device_name}'")
                for device_ip in data.get("ip_addresses", ""):
                    device = db.fetch("device", ip_address=device_ip)
                    if device:
                        devices.append(device.id)
                    else:
                        errors.append(
                            f"No device with the IP address '{device_ip}'")
                for pool_name in data.get("pools", ""):
                    pool = db.fetch("pool", name=pool_name)
                    if pool:
                        pools.append(pool.id)
                    else:
                        errors.append(f"No pool with the name '{pool_name}'")
                if errors:
                    return {"errors": errors}
                if devices or pools:
                    data.update({"devices": devices, "pools": pools})
                data["runtime"] = runtime = app.get_time()
                if handle_asynchronously:
                    # Return the runtime right away so the caller can poll
                    # /rest/result/<name>/<runtime> for the outcome.
                    Thread(target=app.run, args=(service.id, ),
                           kwargs=data).start()
                    return {"errors": errors, "runtime": runtime}
                else:
                    return {**app.run(service.id, **data), "errors": errors}

        class RunTask(Resource):
            decorators = [self.auth.login_required, self.monitor_rest_request]

            def post(self):
                # The JSON body is the task id; the task's service is run in a
                # background thread.
                task = db.fetch("task", rbac="schedule", id=request.get_json())
                data = {
                    "trigger": "Scheduler",
                    "creator": task.last_scheduled_by,
                    "runtime": app.get_time(),
                    "task": task.id,
                    **task.initial_payload,
                }
                if task.devices:
                    data["devices"] = [device.id for device in task.devices]
                if task.pools:
                    data["pools"] = [pool.id for pool in task.pools]
                Thread(target=app.run, args=(task.service.id, ),
                       kwargs=data).start()

        class Topology(Resource):
            decorators = [self.auth.login_required, self.monitor_rest_request]

            def post(self, direction):
                if direction == "import":
                    result = app.import_topology(
                        **{
                            "replace": request.form["replace"] == "True",
                            "file": request.files["file"],
                        })
                    # 206 signals a partial import.
                    status = 206 if "Partial" in result else 200
                    return result, status
                else:
                    app.export_topology(**request.get_json(force=True))
                    return "Topology Export successfully executed."

        class Search(Resource):
            decorators = [self.auth.login_required, self.monitor_rest_request]

            def post(self):
                # Translate the REST body into the DataTables-style
                # arguments that app.filtering expects.
                rest_body = request.get_json(force=True)
                kwargs = {
                    "draw":
                    1,
                    "columns": [{
                        "data": column
                    } for column in rest_body["columns"]],
                    "order": [{
                        "column": 0,
                        "dir": "asc"
                    }],
                    "start":
                    0,
                    "length":
                    rest_body.get("maximum_return_records", 10),
                    "form":
                    rest_body.get("search_criteria", {}),
                    "rest_api_request":
                    True,
                }
                return app.filtering(rest_body["type"], **kwargs)["data"]

        class Sink(Resource):
            # Catch-all: any unknown /rest/... URL gets a JSON 404.
            def get(self, **_):
                rest_abort(
                    404,
                    message=
                    f"The requested {request.method} endpoint does not exist.",
                )

            post = put = patch = delete = get

        for endpoint in app.rest_endpoints:

            # `ep=endpoint` binds the current value (avoids late binding).
            def post(_, ep=endpoint):
                getattr(app, ep)()
                return f"Endpoint {ep} successfully executed."

            api.add_resource(
                type(
                    endpoint,
                    (Resource, ),
                    {
                        "decorators": [
                            self.auth.login_required,
                            self.monitor_rest_request,
                        ],
                        "post":
                        post,
                    },
                ),
                f"/rest/{endpoint}",
            )

        api.add_resource(Heartbeat, "/rest/is_alive")
        api.add_resource(RunService, "/rest/run_service")
        api.add_resource(RunTask, "/rest/run_task")
        api.add_resource(Query, "/rest/query/<string:model>")
        api.add_resource(UpdateInstance, "/rest/instance/<string:model>")
        api.add_resource(GetInstance,
                         "/rest/instance/<string:model>/<string:name>")
        api.add_resource(GetConfiguration, "/rest/configuration/<string:name>")
        api.add_resource(Search, "/rest/search")
        api.add_resource(GetResult,
                         "/rest/result/<string:name>/<string:runtime>")
        api.add_resource(Migrate, "/rest/migrate/<string:direction>")
        api.add_resource(Topology, "/rest/topology/<string:direction>")
        api.add_resource(Sink, "/rest/<path:path>")
Beispiel #28
0
    'processpool': ProcessPoolExecutor(10)
}
job_defaults = {
    # Run every missed execution instead of merging them into one.
    'coalesce': False,
    # Allow up to 10 concurrently running instances of the same job.
    'max_instances': 10
}
scheduler = BackgroundScheduler(jobstores=jobstores, executors=executors, job_defaults=job_defaults, timezone=utc)
scheduler.start()

# Make the momentjs helper available in every Jinja template.
rms.jinja_env.globals['momentjs'] = momentjs
# Each extension below is bound by passing `rms` to its constructor, which
# calls init_app() itself. Calling init_app() again registers the extension
# twice; Flask-SQLAlchemy 3.x raises a RuntimeError in that case.
db = SQLAlchemy(rms)
Session = sessionmaker()
# cache = Cache(rms)
csrf = CSRFProtect(rms)
mail = Mail(rms)

perms = Permissions(rms, db, current_user)
login_manager = LoginManager()
login_manager.init_app(rms)

#auth
#http://pythonhosted.org/Flask-Principal/
#http://flask-restful-cn.readthedocs.org/en/0.3.4/quickstart.html

def setup_logger(logger_name, log_file, level=logging.INFO):
    l = logging.getLogger(logger_name)
    formatter = logging.Formatter('%(asctime)s : %(message)s')
    fileHandler = logging.FileHandler(log_file, mode='a')
    fileHandler.setFormatter(formatter)
Beispiel #29
0
import uuid
import random
import string
from flask_wtf.csrf import CSRFProtect

# CSRF protection object; it is bound to the app further down, once the app
# exists.
csrf = CSRFProtect()

db = SQLAlchemy()
# After defining `db`, import auth models due to
# circular dependency.
from mhn.auth.models import User, Role, ApiKey
# Flask-Security datastore backed by the SQLAlchemy User and Role models.
# NOTE(review): ApiKey is not used in this section; presumably it is imported
# for use or re-export elsewhere. Confirm before removing it.
user_datastore = SQLAlchemyUserDatastore(db, User, Role)

mhn = Flask(__name__)
# Configuration is loaded from the importable `config` module.
mhn.config.from_object('config')
csrf.init_app(mhn)

# Email app setup.
mail = Mail()
mail.init_app(mhn)

# Registering app on db instance.
db.init_app(mhn)

# Setup flask-security for auth.
Security(mhn, user_datastore)

# Registering blueprints.
# Imported only here, presumably because the views module imports objects
# defined above (circular import). Confirm before moving it.
from mhn.api.views import api
mhn.register_blueprint(api)
Beispiel #30
0
def create_app(test_config=None):
    """Create and configure the Flask application.

    :param test_config: optional mapping applied on top of the ``config``
        module instead of the instance ``config.py`` (used when testing).
    :return: the configured ``Flask`` application.
    """
    app = Flask(__name__, instance_relative_config=True)

    app.config.from_object("config")

    if test_config is None:
        # load the instance config, if it exists, when not testing
        app.config.from_pyfile("config.py", silent=True)
    else:
        # load the test config if passed in
        app.config.update(test_config)

    # Sessions expire after 20 minutes of inactivity. This is app-wide
    # configuration, so it is set once here rather than on every request.
    app.permanent_session_lifetime = datetime.timedelta(minutes=20)

    # Connect the db
    db.init_app(app)

    # Set up flask-login
    login_manager = LoginManager()
    login_manager.init_app(app)

    # Set up flask-principal for roles management
    principals = Principal()
    principals.init_app(app)

    # Set up csrf protection; the login blueprint is exempt from the check.
    csrf = CSRFProtect()
    csrf.init_app(app)
    csrf.exempt(login_blueprint)

    app.register_blueprint(login_blueprint)
    app.register_blueprint(admin_blueprint, url_prefix="/admin")
    app.register_blueprint(token_management_blueprint, url_prefix="/user")
    app.register_blueprint(aggregation_unit_blueprint,
                           url_prefix="/spatial_aggregation")

    @app.after_request
    def set_xsrf_cookie(response):
        """
        Sets the csrf token used by csrf protect as a cookie to allow usage with
        react.
        """
        response.set_cookie("X-CSRF", generate_csrf())
        return response

    @app.errorhandler(CSRFError)
    def handle_csrf_error(e):
        """
        CSRF errors are interpreted as an access denied.
        """
        return "CSRF error", 401

    @app.errorhandler(InvalidUsage)
    def handle_invalid_usage(error):
        """Serialise an InvalidUsage error as JSON with its status code."""
        response = flask.jsonify(error.to_dict())
        response.status_code = error.status_code
        return response

    @app.before_request
    def before_request():
        """
        Keep the session inside its 20-minute inactivity window and expose
        the current user as ``flask.g.user``.
        """
        # A permanent session uses app.permanent_session_lifetime; marking it
        # modified makes Flask re-save the cookie, restarting the expiry.
        flask.session.permanent = True
        flask.session.modified = True
        flask.g.user = flask_login.current_user

    @login_manager.user_loader
    def load_user(userid):
        """Helper for flask-login."""
        return User.query.filter(User.id == userid).first()

    @identity_loaded.connect_via(app)
    def on_identity_loaded(sender, identity):
        """Helper for flask-principal."""
        # Set the identity user object
        identity.user = current_user

        # Add the UserNeed to the identity
        if hasattr(current_user, "id"):
            identity.provides.add(UserNeed(current_user.id))

        try:
            if current_user.is_admin:
                identity.provides.add(RoleNeed("admin"))
        except AttributeError:
            pass  # Definitely not an admin

    @app.cli.command("get-fernet")
    def make_fernet_key():
        """
        Generate a new Fernet key for symmetric encryption of data at
        rest.
        """
        print(f'FERNET_KEY="{Fernet.generate_key().decode()}"')

    # Add flask <command> CLI commands
    app.cli.add_command(demodata)
    app.cli.add_command(init_db_command)
    app.cli.add_command(add_admin)
    return app
def init_app(app):
    """
    CloudAlbum application initializer.

    Sets up CSRF protection, error handlers, blueprints, configuration
    taken from ``conf``, rotating-file logging, the login manager and the
    database schema.

    :param app: Flask.app
    :return: initialized application
    :raises SystemExit: with status -1 when the database tables cannot be
        created.
    """
    csrf = CSRFProtect()
    csrf.init_app(app)

    # Register error handlers
    app.register_error_handler(404, errorHandler.not_found)
    # NOTE(review): 405 is routed to the server_error handler — confirm intended.
    app.register_error_handler(405, errorHandler.server_error)
    app.register_error_handler(500, errorHandler.server_error)
    # Presumably catches CSRF failures, which Flask-WTF raises as 400s.
    app.register_error_handler(400, errorHandler.csrf_error)

    # CSRF setup for Flask Blueprint module
    # NOTE(review): csrf.init_app() above may already check every view
    # (unless WTF_CSRF_CHECK_DEFAULT is disabled); these hooks could then
    # validate twice — confirm.
    userView.blueprint.before_request(csrf.protect)
    siteView.blueprint.before_request(csrf.protect)

    # Register Flask Blueprint modules
    app.register_blueprint(siteView.blueprint, url_prefix='/')
    app.register_blueprint(userView.blueprint, url_prefix='/users')
    app.register_blueprint(photoView.blueprint, url_prefix='/photos')

    # Setup application configuration
    app.secret_key = conf['FLASK_SECRET']
    app.config['SQLALCHEMY_DATABASE_URI'] = conf['DB_URL']
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = conf[
        'SQLALCHEMY_TRACK_MODIFICATIONS']
    app.config['SQLALCHEMY_ECHO'] = conf['DB_ECHO_FLAG']

    # SQLITE doesn't support DB connection pool
    if 'sqlite' not in conf['DB_URL'].lower():
        app.config['SQLALCHEMY_POOL_SIZE'] = conf['DB_POOL_SIZE']
        app.config['SQLALCHEMY_MAX_OVERFLOW'] = conf['DB_MAX_OVERFLOW']
        app.config['SQLALCHEMY_POOL_TIMEOUT'] = conf[
            'DB_SQLALCHEMY_POOL_TIMEOUT']
        app.config['SQLALCHEMY_POOL_RECYCLE'] = conf[
            'DB_SQLALCHEMY_POOL_RECYCLE']

    app.jinja_env.globals['url_for_other_page'] = url_for_other_page

    # Logger setup
    app.config['LOGGING_LEVEL'] = get_log_level()
    app.config['LOGGING_FORMAT'] = conf['LOGGING_FORMAT']
    app.config['LOGGING_LOCATION'] = conf['LOG_FILE_PATH']
    app.config['LOGGING_FILENAME'] = os.path.join(conf['LOG_FILE_PATH'],
                                                  conf['LOG_FILE_NAME'])
    app.config['LOGGING_MAX_BYTES'] = conf['LOGGING_MAX_BYTES']
    app.config['LOGGING_BACKUP_COUNT'] = conf['LOGGING_BACKUP_COUNT']

    util.log_path_check(conf['LOG_FILE_PATH'])
    file_handler = RotatingFileHandler(
        app.config['LOGGING_FILENAME'],
        maxBytes=app.config['LOGGING_MAX_BYTES'],
        backupCount=app.config['LOGGING_BACKUP_COUNT'])
    file_handler.setFormatter(Formatter(app.config['LOGGING_FORMAT']))
    file_handler.setLevel(app.config['LOGGING_LEVEL'])

    app.logger.addHandler(file_handler)
    app.logger.setLevel(app.config['LOGGING_LEVEL'])
    app.logger.info("logging start")

    # Setup LoginManager
    login.init_app(app)
    login.login_view = 'userView.signin'

    # Setup models for DB operations
    with app.app_context():
        models.db.init_app(app)
        try:
            models.db.create_all()
        except Exception:
            # Abort start-up when the schema cannot be created. Raise
            # SystemExit directly: the builtin exit() is injected by the
            # `site` module for interactive use and may be absent.
            app.logger.exception("Failed to create database tables")
            raise SystemExit(-1)
    return app
Beispiel #32
0
def create_app(test_config=None):
    """Create the Downing JCR room ballot survey Flask application.

    :param test_config: optional mapping merged into the config after the
        instance ``config.py`` has been loaded.
    :return: the configured ``Flask`` app with all routes registered.
    """

    # Create the app object and import some config

    app = Flask(__name__, instance_relative_config=True)
    app.config.from_mapping(DATABASE=os.path.join(app.instance_path,
                                                  'db.sqlite'), )
    # The instance config.py is mandatory (no silent=True). It used to be
    # loaded a second time for the non-test case, which only re-applied the
    # same values.
    app.config.from_pyfile("config.py")

    if test_config is not None:
        app.config.from_mapping(test_config)

    # Make sure that the instance directory exists
    try:
        os.makedirs(app.instance_path)
    except OSError:
        pass

    # Load middleware and commands

    from flask_wtf.csrf import CSRFProtect
    csrf = CSRFProtect()
    csrf.init_app(app)

    from . import db
    db.init_app(app)

    from . import user
    user.init_app(app)

    from . import mail
    mail.init_app(app)

    from . import allocations
    allocations.init_app(app)

    from . import review
    review.init_app(app)

    from . import survey
    survey.init_app(app)

    # Raven authentication

    # Request class boilerplate adapted from python-ucam-webauth
    # documentation, required so that we can make a hostname
    # whitelist

    class R(flask.Request):
        trusted_hosts = app.config["TRUSTED_HOSTS"]

    app.request_class = R
    auth_decorator = AuthDecorator(desc="Downing JCR Room Ballot Survey")

    # Create the routes

    from roomsurvey.syndicate import get_syndicate_for_user, get_syndicate_invitations, update_invitation, create_syndicate
    from roomsurvey.user import get_user, get_year, is_syndicatable
    from roomsurvey.survey import get_survey_data, import_survey_data, log_survey_data
    from roomsurvey.allocations import get_allocation_for_user
    from roomsurvey.review import has_reviewed, check_review, write_review

    @app.before_request
    def before_request_handler():
        """Populate ``g`` with the authenticated user's details.

        Authenticated users who are not in the database see the
        "unauthorised" page for everything except logout and static files.
        """
        g.crsid = auth_decorator.principal

        # fullname is ONLY set if the user is both authenticated AND in the database
        g.fullname = None

        if g.crsid:
            g.fullname = get_user(g.crsid)
            g.user_year = get_year(g.crsid)

        if (g.crsid and request.path != "/logout"
                and not request.path.startswith("/static") and not g.fullname):
            return render_template("unauthorised.html")

        g.current_time = int(time.time())

    @app.route("/dashboard")
    @auth_decorator
    def dashboard():
        return render_template("dashboard.html",
                               syndicate=get_syndicate_for_user(g.crsid),
                               invites=get_syndicate_invitations(g.crsid),
                               survey_data=get_survey_data(g.crsid))

    @app.route("/syndicate")
    @auth_decorator
    def syndicate():
        return render_template(
            "syndicate.html",
            syndicate=get_syndicate_for_user(g.crsid),
            max_size=app.config["SYNDICATE_MAXSIZE"][g.user_year])

    @app.route("/syndicate/create", methods=["POST"])
    @auth_decorator
    def syndicate_create():
        """Create a syndicate from the posted JSON list of invitee CRSids.

        Any invalid request is answered with 400. Cheap checks run first, so
        a malformed or oversized list is rejected before any per-invitee
        lookups are made.
        """
        if g.current_time > app.config["CLOSE_SYNDICATES"]:
            return abort(400)

        try:
            invitees = json.loads(request.form['invitees-json'])
        except ValueError:
            return abort(400)

        # Must be a list of CRSid strings; anything else would break the
        # set()/join() calls below with a 500.
        if not isinstance(invitees, list) or not all(
                isinstance(i, str) for i in invitees):
            return abort(400)

        if len(invitees) > app.config["SYNDICATE_MAXSIZE"][g.user_year]:
            return abort(400)

        if len(set(invitees)) != len(invitees):
            return abort(400)

        # The creator must be part of their own syndicate (this also rules
        # out an empty list).
        if g.crsid not in invitees:
            return abort(400)

        for i in invitees:
            resp = is_syndicatable(i, g.user_year)
            if not resp["ok"]:
                return abort(400)

        log(g.crsid, "created syndicate and invited " + ",".join(invitees))

        create_syndicate(g.crsid, invitees, g.user_year)
        return redirect("/dashboard", 302)

    @app.route("/invite")
    @auth_decorator
    def invite():
        return render_template("invite.html",
                               invites=get_syndicate_invitations(g.crsid))

    @app.route("/invite/accept", methods=["POST"])
    @auth_decorator
    def invite_accept():
        log(g.crsid, "has accepted a syndicate invitation")
        update_invitation(g.crsid, True)
        return redirect("/dashboard", 302)

    @app.route("/invite/reject", methods=["POST"])
    @auth_decorator
    def invite_reject():
        log(g.crsid, "has rejected a syndicate invitation (WARN)")
        update_invitation(g.crsid, False)
        return redirect("/dashboard", 302)

    @app.route("/api/is_syndicatable/<year>/<crsid>")
    @auth_decorator
    def api_is_syndicatable(year, crsid):
        """Return, as JSON, whether *crsid* can join a syndicate for *year*."""
        try:
            year = int(year)
        except ValueError:
            return abort(400)
        resp = is_syndicatable(crsid, year)
        return json.dumps(resp)

    @app.route("/api/survey_data/" + app.config["COGNITOFORMS_KEY"],
               methods=["POST"])
    @csrf.exempt
    def api_survey_data():
        """Webhook receiving survey submissions (secret key in the URL).

        The body is limited to 64 KiB. Requests without a Content-Length
        (e.g. chunked uploads) are refused with 411 rather than read
        unbounded; comparing None with an int used to raise a TypeError.
        """
        if request.content_length is None:
            return abort(411)
        if request.content_length > 65536:
            return abort(413)

        log_survey_data(request.get_data())
        return import_survey_data(request.get_json())

    @app.route("/survey")
    @auth_decorator
    def survey():
        if g.current_time < app.config["SHOW_SURVEY"]:
            return abort(403)

        return render_template("survey.html",
                               survey_data=get_survey_data(g.crsid))

    @app.route("/")
    def landing():
        """Send logged-in users to the dashboard, others to the landing page."""
        try:
            session["_ucam_webauth"]["state"]["principal"]
            return redirect("/dashboard", 302)
        except KeyError:
            return render_template("landing.html")

    @app.route("/about")
    @auth_decorator
    def about():
        return render_template("about.html")

    @app.route("/logout", methods=["POST"])
    def logout():
        session.clear()
        return redirect("/", code=302)

    @app.route("/allocations")
    @auth_decorator
    def allocations():
        if g.current_time < app.config["SHOW_ALLOCATIONS"]:
            return abort(403)

        return render_template("allocations.html")

    @app.route("/review")
    @auth_decorator
    def review():
        if not app.config["ROOM_REVIEWS"]:
            return render_template("review_no.html")

        room = get_allocation_for_user(g.crsid)
        if room is None:
            return abort(403)

        if has_reviewed(g.crsid):
            return render_template("review_thanks.html")

        return render_template("review.html", room=room)

    @app.route("/review", methods=["POST"])
    @auth_decorator
    def leave_review():
        if not app.config["ROOM_REVIEWS"]:
            return abort(403)

        room = get_allocation_for_user(g.crsid)
        if room is None:
            return abort(403)

        if has_reviewed(g.crsid):
            return abort(403)

        if not check_review(request.form):
            return abort(400)

        write_review(g.crsid, room, request.form)
        return render_template("review_thanks.html")

    # cheeky jinja function override so that we can make database calls from the templates
    # this is a bit of a hack but it makes the python code a lot cleaner
    app.jinja_env.globals.update(get_user=get_user)

    # The app is complete and ready to accept requests

    return app
Beispiel #33
0
def create_app(test_config=None):
    """Create and configure the Flask application.

    :param test_config: optional mapping applied on top of ``get_config()``.
    :return: the configured ``Flask`` application.
    """
    app = Flask(__name__)

    app.config.from_mapping(get_config())

    if test_config is not None:
        # load the test config if passed in
        app.config.update(test_config)

    # Sessions expire after 20 minutes of inactivity. This is app-wide
    # configuration, so it is set once here rather than on every request.
    app.permanent_session_lifetime = datetime.timedelta(minutes=20)

    # Connect the db
    db.init_app(app)

    # Set up flask-login
    login_manager = LoginManager()
    login_manager.init_app(app)

    # Set up flask-principal for roles management
    principals = Principal()
    principals.init_app(app)

    # Set up csrf protection; the login blueprint is exempt from the check.
    csrf = CSRFProtect()
    csrf.init_app(app)
    csrf.exempt(login_blueprint)

    app.register_blueprint(login_blueprint)
    app.register_blueprint(admin_blueprint, url_prefix="/admin")
    app.register_blueprint(token_management_blueprint, url_prefix="/user")
    app.register_blueprint(aggregation_unit_blueprint,
                           url_prefix="/spatial_aggregation")

    # Set the log level
    app.before_first_request(
        partial(app.logger.setLevel, app.config["LOG_LEVEL"]))

    if app.config["DEMO_MODE"]:  # Create demo data
        app.before_first_request(make_demodata)
    else:
        # Initialise the database
        app.before_first_request(partial(init_db,
                                         force=app.config["RESET_DB"]))
        # Create an admin user
        app.before_first_request(
            partial(
                add_admin,
                username=app.config["ADMIN_USER"],
                password=app.config["ADMIN_PASSWORD"],
            ))

    # Cause workers to wait for db to set up
    app.before_first_request(app.config["DB_IS_SET_UP"].wait)

    @app.after_request
    def set_xsrf_cookie(response):
        """
        Sets the csrf token used by csrf protect as a cookie to allow usage with
        react.
        """
        response.set_cookie("X-CSRF", generate_csrf())
        try:
            current_app.logger.debug(
                f"Logged in user was {flask.g.user.username}:{flask.g.user.id}"
            )
            # NOTE(review): this writes the whole session (including the CSRF
            # token) to the debug log — confirm that is acceptable.
            current_app.logger.debug(flask.session)
        except AttributeError:
            current_app.logger.debug("User was not logged in.")
        return response

    @app.errorhandler(CSRFError)
    def handle_csrf_error(e):
        """
        CSRF errors are interpreted as an access denied.
        """
        return "CSRF error", 401

    @app.errorhandler(InvalidUsage)
    def handle_invalid_usage(error):
        """Serialise an InvalidUsage error as JSON with its status code."""
        response = flask.jsonify(error.to_dict())
        response.status_code = error.status_code
        return response

    @app.before_request
    def before_request():
        """
        Keep the session inside its 20-minute inactivity window and expose
        the current user as ``flask.g.user``.
        """
        # A permanent session uses app.permanent_session_lifetime; marking it
        # modified makes Flask re-save the cookie, restarting the expiry.
        flask.session.permanent = True
        flask.session.modified = True
        flask.g.user = flask_login.current_user
        try:
            current_app.logger.debug(
                f"Logged in user is {flask.g.user.username}:{flask.g.user.id}")
            current_app.logger.debug(flask.session)
        except AttributeError:
            current_app.logger.debug("User is not logged in.")

    @login_manager.user_loader
    def load_user(userid):
        """Helper for flask-login."""
        return User.query.filter(User.id == userid).first()

    @identity_loaded.connect_via(app)
    def on_identity_loaded(sender, identity):
        """Helper for flask-principal."""
        # Set the identity user object
        identity.user = current_user

        # Add the UserNeed to the identity
        if hasattr(current_user, "id"):
            identity.provides.add(UserNeed(current_user.id))

        try:
            if current_user.is_admin:
                identity.provides.add(RoleNeed("admin"))
        except AttributeError:
            pass  # Definitely not an admin

    @app.cli.command("get-fernet")
    def make_flowauth_fernet_key():
        """
        Generate a new Fernet key for symmetric encryption of data at
        rest.
        """
        print(f'FLOWAUTH_FERNET_KEY="{Fernet.generate_key().decode()}"')

    # Add flask <command> CLI commands
    app.cli.add_command(demodata)
    app.cli.add_command(init_db_command)
    app.cli.add_command(add_admin_command)
    return app
Beispiel #34
0
def create_app(test_config=None):
    """Create and configure the Noos Flask application.

    Defaults are set first, then overridden by ``<root_path>/config.py``
    when that file exists, then by *test_config* when given.

    :param test_config: optional mapping applied last (used by tests).
    :return: the configured ``Flask`` application.
    """
    app = Flask(__name__, instance_relative_config=True)
    app.config.from_mapping(
        SECRET_KEY='dev',
        ES_ADDRESS="localhost",
        USER_DATABASE=os.path.join(app.instance_path, 'noos_users.sqlite'),
        # NOTE(review): this value looks mangled (credential redaction seems
        # to have merged several settings into one string) — confirm HOST.
        HOST="http://*****:*****@noos-citoyens.fr",
        ACCOUNT_CREATION_NEEDS_INVITATION=False,
        AUTH_TOKEN_MAX_AGE=30 * 24 * 3600,  # default 30 days
        RECOVERY_TOKEN_MAX_AGE=24 * 3600,  # default 1 day
        USER_ACCOUNT_LIMIT=0,

        # == SMTP ==
        MAIL_SERVER="localhost",
        # NOTE(review): port 465 usually means implicit SSL, while
        # MAIL_USE_TLS selects STARTTLS (usually port 587) — confirm.
        MAIL_PORT="465",  # "587"
        MAIL_USE_TLS=True,
        MAIL_USERNAME=None,  # "*****@*****.**",
        MAIL_PASSWORD=None,  #"secret",
    )

    conf_py = app.root_path + '/config.py'

    if os.path.exists(conf_py):
        app.logger.info("reading config %s", conf_py)
        app.config.from_pyfile(conf_py)

    if test_config is not None:
        app.config.from_mapping(test_config)
    try:
        os.makedirs(app.instance_path)
    except OSError:
        pass

    login_manager = LoginManager()
    login_manager.init_app(app)

    # Flask-Mail
    from flask_mail import Mail
    app.mail = Mail(app)

    from flask_wtf.csrf import CSRFProtect
    # Per Flask-WTF, with WTF_CSRF_CHECK_DEFAULT off the CSRF check is not
    # applied globally; views must opt in explicitly.
    app.config["WTF_CSRF_CHECK_DEFAULT"] = False
    # Passing the app to the constructor already runs init_app(); calling it
    # again would register the CSRF before_request hook twice.
    CSRFProtect(app)

    from . import datastorage
    datastorage.init_app(app)

    with app.app_context():

        from . import auth
        auth_api = auth.blueprint("auth", __name__, url_prefix='/compte')
        app.register_blueprint(auth_api)

        app.register_error_handler(auth.UserAccountLimitException,
                                   auth.handleUserAccountLimitException)
        app.register_error_handler(auth.EmailExistsValidationError,
                                   auth.handleEmailExistsValidationError)
        app.register_error_handler(auth.EmailNotExistsAccountError,
                                   auth.handleEmailNotExistsAccountError)
        app.register_error_handler(auth.UserAccountNotActiveError,
                                   auth.handleUserAccountNotActiveError)
        app.register_error_handler(auth.SendMailError,
                                   auth.handleSendMailError)

        # (exception class or status code, status, title, default message)
        errors = (
            (auth.UserLoginNotFoundError, 401, "User not found", ""),
            (401, 401, "Unauthorized",
             "The server could not verify that you are authorized to access "
             "the URL requested. You either supplied the wrong credentials "
             "(e.g. a bad password), or your browser doesn't understand how "
             "to supply the credentials required."),
            (404, 404, "Page not found",
             "The requested URL was not found on the server. If you entered "
             "the URL manually please check your spelling and try again."),
            (405, 405, "Method Not Allowed",
             "The method is not allowed for the requested URL."),
            (500, 500, "", ""),
        )

        def handle_error(code, title, message):
            """Return an error handler rendering ``40x.xhtml`` for *code*.

            401 errors redirect to the index page instead.
            """
            def wrapped(e):
                # Prefer the exception's own message, otherwise the default
                # text configured for this status (previously a local
                # variable shadowed it, so the defaults were never shown).
                detail = getattr(e, 'message', None) or message
                params = {
                    'error_code': code,
                    'error_title': title,
                    'error_message': detail
                }

                if code == 500:
                    params['error_title'] = "C'est une erreur."
                    params[
                        'error_message'] = "There is something wrong. %s " % detail

                if code == 401:
                    return redirect(url_for('index'))
                else:
                    return render_template('40x.xhtml', **params), code

            return wrapped

        for code_or_err, status, title, msg in errors:
            app.register_error_handler(code_or_err,
                                       handle_error(status, title, msg))

    @login_manager.user_loader
    def load_user(uuid):
        """Helper for flask-login: look the user up by uuid."""
        return datastorage.user_db.Users.get_one_by('uuid', uuid)

    @app.route('/')
    def index():
        data = {'a': 1, 'b': 2}
        return render_template(
            'index.xhtml',
            title="Noos - plateforme de revendications citoyennes",
            data=data)

    @app.route('/mentionslegales')
    def legal():
        return render_template('mentionslegales.xhtml',
                               title="NOos - Mentions légales")

    @app.route('/donnees')
    def donnees():
        return render_template('donnees.xhtml',
                               title="NOos - Gestion des données")

    @app.route('/test')
    def test_page():
        return render_template('test.xhtml', title='page de test')

    @app.route('/search_propositions', methods=['POST'])
    @login_required
    def test_query():
        """Full-text search over propositions for the JSON body ``query``.

        ``limit`` is capped at 10 and ``start`` is the offset of the first
        hit. A malformed body or malformed paging values give 400.
        """
        params = request.get_json(force=True)
        if not isinstance(params, dict):
            return abort(400)
        q = params.get('query', None)
        try:
            # Clamp client-supplied paging values to non-negative ints.
            limit = max(0, min(10, int(params.get('limit', 10))))
            start = max(0, int(params.get('start', 0)))
        except (TypeError, ValueError):
            return abort(400)
        if q:
            results = datastorage.Proposition.simple_search(q, start, limit)
            data = []
            props = results['hits']
            for p in props:
                data.append(p.to_dict())
                data[-1]['id'] = p.meta.id
            return jsonify({"count": results['count'], "hits": data})
        else:
            return jsonify([])

    # pour TinaWebJS
    # info_div.php?ndtype=0&dbtype=csv&query=["gestion","autoroutes","partie"]&gexf=data/p/graph.gexf&n=10
    @app.route('/lookup_propositions', methods=['GET', "OPTIONS"])
    def graph_query():
        """Look up propositions for a JSON list of terms in ``?query=``.

        A missing query yields no hits. A query that is not valid JSON, or
        not a list of strings, gives 400 instead of a server error.
        """
        if request.method == "OPTIONS":
            return app.make_default_options_response()
        import json
        raw_query = request.args.get("query")
        try:
            query = json.loads(raw_query) if raw_query else []
        except ValueError:
            return abort(400)
        result = []
        if query:
            if not isinstance(query, list) or not all(
                    isinstance(term, str) for term in query):
                return abort(400)
            props = datastorage.Proposition.simple_search(
                " ".join(query), 0, 50)
            result = [{
                'src': p['cause'],
                'txt': p['content']
            } for p in props['hits']]
        return jsonify({'hits': result})

    #app.add_url_rule('/lookup_propositions', graph_query, methods=["GET"], provide_automatic_options=True)

    @app.route('/newprop', methods=['GET', 'POST'])
    @login_required
    def new_prop():
        """Show the proposition form (GET) or save a new proposition (POST).

        A proposition is saved only when its content is 6-499 characters
        long; otherwise, or when saving fails, the form is shown again.
        """
        # login_required guarantees an authenticated user here (the former
        # `current_user is None` test could never be true for the proxy).
        if request.method == "GET":
            return render_template('newprop.xhtml',
                                   title="faire une proposition")
        else:
            try:
                cause = html.escape(request.form.get('cause'),
                                    quote=False).strip()
                content = html.escape(request.form.get('content'),
                                      quote=False).strip()
                date = datetime.now()
                ip = request.remote_addr
                if ip is None or ip == '':
                    if 'X-Forwarded-For' in request.headers:
                        ip = request.headers.getlist(
                            "X-Forwarded-For")[0].rpartition(',')[-1].strip()
                    else:
                        ip = "127.0.0.1"
                if 5 < len(content) < 500:
                    p = datastorage.Proposition(ip=ip,
                                                uid=current_user.uuid,
                                                cause=cause,
                                                content=content,
                                                date=date)
                    p.save()
                    return redirect(
                        url_for('get_proposition', id=p.meta.id,
                                msg="nouveau"))
                else:
                    return render_template(
                        'newprop.xhtml', title="faire une proposition"
                    )  #todo : message d'erreur avec lien vers les guidelines
            except Exception:
                # Best effort: show the form again, but keep a trace of why.
                app.logger.exception("could not save proposition")
                return render_template('newprop.xhtml',
                                       title="faire une proposition")

    @app.route('/proposition/<string:id>/<string:msg>')
    @app.route('/proposition/<string:id>')
    @app.route('/proposition')
    def get_proposition(id=None, msg=None):
        """Render one proposition; 404 when no id is given or it is unknown."""
        # `/proposition` carries no id: answer 404 instead of passing None
        # to Proposition.get().
        if id is None:
            return abort(404)
        p = datastorage.Proposition.get(id, ignore=404)
        if p is not None:
            data = p.to_dict()
            if msg:
                data['msg'] = """
                Merci de votre contribution, il vous sera possible prochainement d'accéder à la liste complète de toutes les propositions
                """
            else:
                data['msg'] = ""
            data['id'] = p.meta.id
            user = datastorage.user_db.Users.get_one_by('uuid', p.uid)
            if user is not None:
                data['username'] = user.username
            else:
                data['username'] = "******"
            return render_template('proposition.xhtml',
                                   prop=data,
                                   host=current_app.config['HOST'])
        else:
            return abort(404)

    @app.route('/guide')
    def guide():
        return render_template('guidelines.xhtml',
                               title="Conseils de rédaction")

    @app.route('/isc-explorer')
    def isc():
        return render_template('isc-explorer.xhtml',
                               title="vision d'ensemble des propositions")

    @app.route('/isc-frame')
    def isc_frame():
        return render_template('isc-frame.xhtml',
                               title="vision d'ensemble des propositions")

    @app.route('/datasets')
    def datasets():
        return render_template("datasets.xhtml", title="Données Ouvertes")

    return app
Beispiel #35
0
def create_app(test_config=None):
    """Create and configure the free_shark Flask application.

    :param test_config: mapping applied instead of the instance
        ``config.py`` when given (used by tests).
    :return: the configured ``Flask`` application.
    """
    app = Flask(__name__)
    # Each optional config file is loaded on its own, so a missing file no
    # longer stops the following ones from being read. silent=True only
    # ignores missing files; a broken config file still raises.
    for cfg_file in ('free_shark.cfg', 'db_config.cfg', 'mail_config.cfg'):
        app.config.from_pyfile(cfg_file, silent=True)
    app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
    # NOTE(review): CSRF checking is disabled here unless a later config
    # source re-enables it — confirm this is intended.
    app.config['WTF_CSRF_ENABLED'] = False
    if test_config is None:
        # load the instance config, if it exists, when not testing
        app.config.from_pyfile('config.py', silent=True)
    else:
        # load the test config if passed in
        app.config.from_mapping(test_config)

    # Values returned by load_config_from_envvar() are applied last.
    env_config = load_config_from_envvar()
    app.config.from_mapping(env_config)
    # ensure the instance folder exists
    try:
        os.makedirs(app.instance_path)
    except OSError:
        pass

    debug_flag = app.config['DEBUG']

    db.init_app(app)

    app.register_blueprint(auth.bp)
    app.register_blueprint(resources.bp)
    app.register_blueprint(comController.bp)
    app.register_blueprint(admin.bp)

    app.add_template_global(set_var, 'set_var')
    app.add_template_global(get_var, 'get_var')

    # Templates see HEARTBEAT_FLAG as the inverse of DEBUG.
    app.jinja_env.globals['HEARTBEAT_FLAG'] = not debug_flag

    from urllib.parse import urlencode
    from free_shark.utils import replace_dict
    app.jinja_env.filters['urlencode'] = urlencode
    app.jinja_env.filters['replace_dict'] = replace_dict
    app.jinja_env.filters['set_default'] = set_default

    app.register_error_handler(403, frobidden_handler)

    # Constructing Principal / CSRFProtect with the app already runs
    # init_app(); a second init_app() call would register their
    # before_request hooks twice.
    Principal(app)

    login_manager = LoginManager()
    login_manager.init_app(app)

    bootstrap = Bootstrap()
    bootstrap.init_app(app)

    CSRFProtect(app)

    from free_shark.utils import api_limiter, mail
    api_limiter.init_app(app)
    mail.init_app(app)

    @login_manager.user_loader
    def load_user(userid):
        """Resolve the session user id for flask-login.

        Numeric ids are looked up by id; anything else (or a failed id
        lookup) is tried as an authentication token.
        """
        try:
            numeric_id = int(userid)
            return user.User.get_user_by_id(numeric_id)
        except Exception:
            return user.User.get_user_by_token(userid)

    @identity_loaded.connect_via(app)
    def on_identity_loaded(sender, identity):
        """Attach the current user's id and roles to the principal identity."""
        identity.user = current_user
        if current_user is not None and not current_user.is_anonymous:
            identity.provides.add(UserNeed(current_user.id))
            for role in current_user.role:
                identity.provides.add(RoleNeed(role))
        else:
            identity.provides.add(RoleNeed("anonymous"))

    return app
Beispiel #36
0
        # If no letters are provided, then a pattern must be provided
        if letters == "" and patt == "":
            e = "Need pattern if no letters"
        # ..g for patter ask for letters ... can't match these
        if letters != "" and patt != "":
            for r in patt.replace('.', ''):
                if r not in letters:
                    e="Pattern letter not in letters"
        return e
    # f72e9bac-8434-4aa7-852c-bf16c49917b6er


import os  # the module's import section is outside this view

csrf = CSRFProtect()
app = Flask(__name__) # gunicorn will find you
# Do not rely on a signing key committed to source: take SECRET_KEY from the
# environment and keep the old literal only as a development fallback.
app.config["SECRET_KEY"] = os.environ.get("SECRET_KEY", "row the boat")
csrf.init_app(app)

# The route has to be "/" ... otherwise good luck getting heroku to deploy.
@app.route('/')
def index():
    """Render the landing page with an empty word-search form."""
    return render_template(
        "index.html",
        form=WordForm(),
        name="Morris Ombiro",
    )

@app.route('/words', methods=['POST','GET'])
def letters_2_words():
    """Handle submission of the word-search form.

    NOTE(review): the body appears truncated — as written every path falls
    off the end and returns None, which Flask would presumably reject as an
    invalid view response. Confirm the missing lookup/render code.
    """
    form = WordForm()
    if form.validate_on_submit():
        # Submitted fields: available letters, requested length, pattern.
        letters = form.avail_letters.data
        len_ = form.choose_length.data
        pat_ = form.pattern.data
from flask import Flask
from flask_wtf.csrf import CSRFProtect

# CSRF extension created unbound; it is attached to the app just below.
csrf = CSRFProtect()

application = Flask(__name__)
csrf.init_app(application)

# Load settings from the importable `config` module.
application.config.from_object('config')

# Imported after `application` exists — presumably because these modules
# import `application` from this package (circular import), and the wildcard
# import is presumably what registers the route handlers. TODO confirm.
from app.models import Task
from app.controllers import *

if __name__ == '__main__':
    # Run Flask's built-in server on port 8000 when executed directly.
    application.run(port=8000)