Example #1
def _generate_mail_content(schedule, screenshot, name, url):
    if schedule.delivery_type == EmailDeliveryType.attachment:
        images = None
        data = {
            'screenshot.png': screenshot,
        }
        body = __(
            '<b><a href="%(url)s">Explore in Superset</a></b><p></p>',
            name=name,
            url=url,
        )
    elif schedule.delivery_type == EmailDeliveryType.inline:
        # Get the domain from the 'From' address ..
        # and make a message id without the < > in the ends
        domain = parseaddr(config.get('SMTP_MAIL_FROM'))[1].split('@')[1]
        msgid = make_msgid(domain)[1:-1]

        images = {
            msgid: screenshot,
        }
        data = None
        body = __(
            """
            <b><a href="%(url)s">Explore in Superset</a></b><p></p>
            <img src="cid:%(msgid)s">
            """,
            name=name, url=url, msgid=msgid,
        )

    return EmailContent(body, data, images)
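
Throughout these examples, `__` appears to be flask_babel's gettext (Superset imports it as `from flask_babel import gettext as __`): `%(key)s` placeholders in the string are filled from the keyword arguments after translation. A minimal, self-contained sketch of that behaviour, assuming flask-babel is installed (the app and URL below are made up for illustration):

from flask import Flask
from flask_babel import Babel, gettext as __

app = Flask(__name__)
Babel(app)  # no translation catalogs configured; strings pass through unchanged

with app.app_context():
    # %(url)s is substituted from the keyword argument, as in _generate_mail_content above
    body = __(
        '<b><a href="%(url)s">Explore in Superset</a></b><p></p>',
        url='https://superset.example.com/superset/dashboard/1/',
    )
    print(body)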
Example #2
def user_reset(token=None):
  user_db = model.User.get_by('token', token)
  if not user_db:
    flask.flash(__('That link is either invalid or expired.'), category='danger')
    return flask.redirect(flask.url_for('welcome'))

  if auth.is_logged_in():
    flask_login.logout_user()
    return flask.redirect(flask.request.path)

  form = UserResetForm()
  if form.validate_on_submit():
    user_db.password_hash = util.password_hash(user_db, form.new_password.data)
    user_db.token = util.uuid()
    user_db.verified = True
    user_db.put()
    flask.flash(__('Your password was changed successfully.'), category='success')
    return auth.signin_user_db(user_db)

  return flask.render_template(
    'user/user_reset.html',
    title='Reset Password',
    html_class='user-reset',
    form=form,
    user_db=user_db,
  )
Example #3
def user_verify(token):
  user_db = auth.current_user_db()
  if user_db.token != token:
    flask.flash(__('That link is either invalid or expired.'), category='danger')
    return flask.redirect(flask.url_for('profile'))
  user_db.verified = True
  user_db.token = util.uuid()
  user_db.put()
  flask.flash(__('Hooray! Your email is now verified.'), category='success')
  return flask.redirect(flask.url_for('profile'))
Example #4
def _get_slice_data(schedule):
    slc = schedule.slice

    slice_url = _get_url_path(
        'Superset.explore_json',
        csv='true',
        form_data=json.dumps({'slice_id': slc.id}),
    )

    # URL to include in the email
    url = _get_url_path(
        'Superset.slice',
        slice_id=slc.id,
    )

    cookies = {}
    for cookie in _get_auth_cookies():
        cookies['session'] = cookie

    response = requests.get(slice_url, cookies=cookies)
    response.raise_for_status()

    # TODO: Move to the csv module
    rows = [r.split(b',') for r in response.content.splitlines()]

    if schedule.delivery_type == EmailDeliveryType.inline:
        data = None

        # Parse the csv file and generate HTML
        columns = rows.pop(0)
        with app.app_context():
            body = render_template(
                'superset/reports/slice_data.html',
                columns=columns,
                rows=rows,
                name=slc.slice_name,
                link=url,
            )

    elif schedule.delivery_type == EmailDeliveryType.attachment:
        data = {
            __('%(name)s.csv', name=slc.slice_name): response.content,
        }
        body = __(
            '<b><a href="%(url)s">Explore in Superset</a></b><p></p>',
            name=slc.slice_name,
            url=url,
        )

    return EmailContent(body, data, None)
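
The TODO above notes that the hand-rolled `split(b',')` should move to the csv module; naive splitting breaks on quoted fields that contain commas. A possible replacement, sketched under that assumption (the helper name and encoding are illustrative, not part of the original code):

import csv
import io

def parse_csv_bytes(content: bytes, encoding: str = 'utf-8'):
    """Parse raw CSV bytes into a header row and data rows, honouring quoting."""
    reader = csv.reader(io.StringIO(content.decode(encoding)))
    rows = list(reader)
    return rows[0], rows[1:]  # columns, data rows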
Example #5
def signin():
  next_url = util.get_next_url()
  form = None
  if config.CONFIG_DB.has_email_authentication:
    form = form_with_recaptcha(SignInForm())
    save_request_params()
    if form.validate_on_submit():
      result = get_user_db_from_email(form.email.data, form.password.data)
      if result:
        cache.reset_auth_attempt()
        return signin_user_db(result)
      if result is None:
        form.email.errors.append(__('Email or Password do not match'))
      if result is False:
        return flask.redirect(flask.url_for('welcome'))
    if not form.errors:
      form.next_url.data = next_url

  if form and form.errors:
    cache.bump_auth_attempt()

  return flask.render_template(
    'auth/auth.html',
    title=_('Sign in'),
    html_class='auth',
    next_url=next_url,
    form=form,
    form_type='signin' if config.CONFIG_DB.has_email_authentication else '',
    **urls_for_oauth(next_url)
  )
Example #6
def user_activate(token):
  if auth.is_logged_in():
    flask_login.logout_user()
    return flask.redirect(flask.request.path)

  user_db = model.User.get_by('token', token)
  if not user_db:
    flask.flash(__('That link is either invalid or expired.'), category='danger')
    return flask.redirect(flask.url_for('welcome'))

  form = UserActivateForm(obj=user_db)
  if form.validate_on_submit():
    form.populate_obj(user_db)
    user_db.password_hash = util.password_hash(user_db, form.password.data)
    user_db.token = util.uuid()
    user_db.verified = True
    user_db.put()
    return auth.signin_user_db(user_db)

  return flask.render_template(
    'user/user_activate.html',
    title='Activate Account',
    html_class='user-activate',
    user_db=user_db,
    form=form,
  )
Example #7
def signup():
  next_url = util.get_next_url()
  form = None
  if config.CONFIG_DB.has_email_authentication:
    form = form_with_recaptcha(SignUpForm())
    save_request_params()
    if form.validate_on_submit():
      user_db = model.User.get_by('email', form.email.data)
      if user_db:
        form.email.errors.append(__('This email is already taken.'))

      if not form.errors:
        user_db = create_user_db(
          None,
          util.create_name_from_email(form.email.data),
          form.email.data,
          form.email.data,
        )
        user_db.put()
        task.activate_user_notification(user_db)
        cache.bump_auth_attempt()
        return flask.redirect(flask.url_for('welcome'))

  if form and form.errors:
    cache.bump_auth_attempt()

  title = _('Sign up') if config.CONFIG_DB.has_email_authentication else _('Sign in')
  return flask.render_template(
    'auth/auth.html',
    title=title,
    html_class='auth',
    next_url=next_url,
    form=form,
    **urls_for_oauth(next_url)
  )
Example #8
def notify_user_about_perm_udate(
        granter, user, role, datasource, tpl_name, config):
    msg = render_template(tpl_name, granter=granter, user=user, role=role,
                          datasource=datasource)
    logging.info(msg)
    subject = __('[Superset] Access to the datasource %(name)s was granted',
                 name=datasource.full_name)
    send_email_smtp(user.email, subject, msg, config, bcc=granter.email,
                    dryrun=not config.get('EMAIL_NOTIFICATIONS'))
Example #9
def _register_schedule_menus():
    appbuilder.add_separator('Manage')

    appbuilder.add_view(
        DashboardEmailScheduleView,
        'Dashboard Email Schedules',
        label=__('Dashboard Emails'),
        category='Manage',
        category_label=__('Manage'),
        icon='fa-search')

    appbuilder.add_view(
        SliceEmailScheduleView,
        'Chart Emails',
        label=__('Chart Email Schedules'),
        category='Manage',
        category_label=__('Manage'),
        icon='fa-search')
Example #10
def facebook_authorized():
  response = facebook.authorized_response()
  if response is None:
    flask.flash(__('You denied the request to sign in.'))
    return flask.redirect(util.get_next_url())

  flask.session['oauth_token'] = (response['access_token'], '')
  me = facebook.get('/me?fields=name,email')
  user_db = retrieve_user_from_facebook(me.data)
  return auth.signin_user_db(user_db)
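
The access token stored in flask.session['oauth_token'] above is normally read back through a token getter registered on the remote app, so that subsequent facebook.get() calls are authenticated. A minimal sketch assuming a flask-oauthlib style remote app (the keys and secrets are placeholders, not part of the original code):

from flask import Flask, session
from flask_oauthlib.client import OAuth

app = Flask(__name__)
app.secret_key = 'dev'  # required for session storage
oauth = OAuth(app)
facebook = oauth.remote_app(
    'facebook',
    consumer_key='FB_APP_ID',          # placeholder
    consumer_secret='FB_APP_SECRET',   # placeholder
    request_token_params={'scope': 'email'},
    base_url='https://graph.facebook.com',
    request_token_url=None,
    access_token_url='/oauth/access_token',
    authorize_url='https://www.facebook.com/dialog/oauth',
)

@facebook.tokengetter
def get_facebook_oauth_token():
    # mirrors the session key set in facebook_authorized() above
    return session.get('oauth_token')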
Example #11
def google_authorized():
  response = google.authorized_response()
  if response is None:
    flask.flash(__('You denied the request to sign in.'))
    return flask.redirect(util.get_next_url())

  flask.session['oauth_token'] = (response['access_token'], '')
  me = google.get('me', data={'access_token': response['access_token']})
  user_db = retrieve_user_from_google(me.data)
  return auth.signin_user_db(user_db)
Example #12
def twitter_authorized():
  response = twitter.authorized_response()
  if response is None:
    flask.flash(__('You denied the request to sign in.'))
    return flask.redirect(util.get_next_url())

  flask.session['oauth_token'] = (
    response['oauth_token'],
    response['oauth_token_secret'],
  )
  user_db = retrieve_user_from_twitter(response)
  return auth.signin_user_db(user_db)
Example #13
def deliver_dashboard(schedule):
    """
    Given a schedule, deliver the dashboard as an email report
    """
    dashboard = schedule.dashboard

    dashboard_url = _get_url_path(
        'Superset.dashboard',
        dashboard_id=dashboard.id,
    )

    # Create a driver, fetch the page, wait for the page to render
    driver = create_webdriver()
    window = config.get('WEBDRIVER_WINDOW')['dashboard']
    driver.set_window_size(*window)
    driver.get(dashboard_url)
    time.sleep(PAGE_RENDER_WAIT)

    # Set up a function to retry once for the element.
    # This is buggy in certain selenium versions with firefox driver
    get_element = getattr(driver, 'find_element_by_class_name')
    element = retry_call(
        get_element,
        fargs=['grid-container'],
        tries=2,
        delay=PAGE_RENDER_WAIT,
    )

    try:
        screenshot = element.screenshot_as_png
    except WebDriverException:
        # Some webdrivers do not support screenshots for elements.
        # In such cases, take a screenshot of the entire page.
        screenshot = driver.screenshot()  # pylint: disable=no-member
    finally:
        destroy_webdriver(driver)

    # Generate the email body and attachments
    email = _generate_mail_content(
        schedule,
        screenshot,
        dashboard.dashboard_title,
        dashboard_url,
    )

    subject = __(
        '%(prefix)s %(title)s',
        prefix=config.get('EMAIL_REPORTS_SUBJECT_PREFIX'),
        title=dashboard.dashboard_title,
    )

    _deliver_email(schedule, subject, email)
Example #14
def signin_user_db(user_db):
  if not user_db:
    return flask.redirect(flask.url_for('signin'))
  flask_user_db = FlaskUser(user_db)
  auth_params = flask.session.get('auth-params', {
    'next': flask.url_for('welcome'),
    'remember': False,
  })
  flask.session.pop('auth-params', None)
  if flask_login.login_user(flask_user_db, remember=auth_params['remember']):
    user_db.put_async()
    return util.set_locale(
      user_db.locale,
      flask.redirect(util.get_next_url(auth_params['next'])),
    )
  flask.flash(__('Sorry, but you could not sign in.'), category='danger')
  return flask.redirect(flask.url_for('signin'))
Example #15
def deliver_slice(schedule):
    """
    Given a schedule, deliver the slice as an email report
    """
    if schedule.email_format == SliceEmailReportFormat.data:
        email = _get_slice_data(schedule)
    elif schedule.email_format == SliceEmailReportFormat.visualization:
        email = _get_slice_visualization(schedule)
    else:
        raise RuntimeError('Unknown email report format')

    subject = __(
        '%(prefix)s %(title)s',
        prefix=config.get('EMAIL_REPORTS_SUBJECT_PREFIX'),
        title=schedule.slice.slice_name,
    )

    _deliver_email(schedule, subject, email)
Example #16
    def save(self):
        datasource = json.loads(request.form.get('data'))
        datasource_id = datasource.get('id')
        datasource_type = datasource.get('type')
        orm_datasource = ConnectorRegistry.get_datasource(
            datasource_type, datasource_id, db.session)

        if not check_ownership(orm_datasource, raise_if_false=False):
            return json_error_response(
                __(
                    'You are not authorized to modify '
                    'this data source configuration'),
                status='401',
            )
        orm_datasource.update_from_object(datasource)
        data = orm_datasource.data
        db.session.commit()
        return self.json_response(data)
Example #17
def feedback():
  if not config.CONFIG_DB.feedback_email:
    return flask.abort(418)

  form = FeedbackForm(obj=auth.current_user_db())
  if not config.CONFIG_DB.has_anonymous_recaptcha or auth.is_logged_in():
    del form.recaptcha
  if form.validate_on_submit():
    body = '%s\n\n%s' % (form.message.data, form.email.data)
    kwargs = {'reply_to': form.email.data} if form.email.data else {}
    task.send_mail_notification('%s...' % body[:48].strip(), body, **kwargs)
    flask.flash(__('Thank you for your feedback!'), category='success')
    return flask.redirect(flask.url_for('welcome'))

  return flask.render_template(
    'feedback.html',
    title=_('Feedback'),
    html_class='feedback',
    form=form,
  )
Example #18
    def test_deliver_slice_csv_attachment(self, send_email_smtp, get):
        response = Mock()
        get.return_value = response
        response.raise_for_status.return_value = None
        response.content = self.CSV

        schedule = db.session.query(SliceEmailSchedule).filter_by(
            id=self.slice_schedule).all()[0]

        schedule.email_format = SliceEmailReportFormat.data
        schedule.delivery_type = EmailDeliveryType.attachment

        deliver_slice(schedule)
        send_email_smtp.assert_called_once()

        file_name = __('%(name)s.csv', name=schedule.slice.slice_name)

        self.assertEqual(
            send_email_smtp.call_args[1]['data'][file_name],
            self.CSV,
        )
Example #19
from ..models.connector import AppsFlyerConnector


# Shows all Connectors in the system
class AppsFlyerConnectorView(SupersetModelView, DeleteMixin):
    """View For Connector Model."""

    datamodel = SQLAInterface(AppsFlyerConnector)

    list_columns = ['name', 'app_id', 'admin_data_sources']

    add_columns = ['name', 'app_id', 'api_token', 'url_pat']

    edit_columns = add_columns

    @action('web_test', 'web_test', 'web_test', 'fa-rocket')
    def web_test(self, items):
        if not isinstance(items, list):
            items = [items]
        items[0].web_test()
        return redirect('appsflyerconnectorview/list/')


# Register ConnectorView Model View
appbuilder.add_view(AppsFlyerConnectorView,
                    'AppsFlyer',
                    icon='fa-random',
                    category='Connectors',
                    category_icon='fa-rocket',
                    category_label=__('Connectors'))
Example #20
class DashboardModelView(DashboardMixin, SupersetModelView, DeleteMixin):  # pylint: disable=too-many-ancestors
    route_base = "/dashboard"
    datamodel = SQLAInterface(DashboardModel)
    # TODO disable api_read and api_delete (used by cypress)
    # once we move to ChartRestModelApi
    class_permission_name = "Dashboard"
    method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP

    include_route_methods = RouteMethod.CRUD_SET | {
        RouteMethod.API_READ,
        RouteMethod.API_DELETE,
        "download_dashboards",
    }

    @has_access
    @expose("/list/")
    def list(self) -> FlaskResponse:
        if not is_feature_enabled("ENABLE_REACT_CRUD_VIEWS"):
            return super().list()

        return super().render_app_template()

    @action("mulexport", __("Export"), __("Export dashboards?"), "fa-database")
    def mulexport(  # pylint: disable=no-self-use
        self, items: Union["DashboardModelView",
                           List["DashboardModelView"]]) -> FlaskResponse:
        if not isinstance(items, list):
            items = [items]
        ids = "".join("&id={}".format(d.id) for d in items)
        return redirect("/dashboard/export_dashboards_form?{}".format(ids[1:]))

    @event_logger.log_this
    @has_access
    @expose("/export_dashboards_form")
    def download_dashboards(self) -> FlaskResponse:
        if request.args.get("action") == "go":
            ids = request.args.getlist("id")
            return Response(
                DashboardModel.export_dashboards(ids),
                headers=generate_download_headers("json"),
                mimetype="application/text",
            )
        return self.render_template("superset/export_dashboards.html",
                                    dashboards_url="/dashboard/list")

    def pre_add(self, item: "DashboardModelView") -> None:
        item.slug = item.slug or None
        if item.slug:
            item.slug = item.slug.strip()
            item.slug = item.slug.replace(" ", "-")
            item.slug = re.sub(r"[^\w\-]+", "", item.slug)
        if g.user not in item.owners:
            item.owners.append(g.user)
        utils.validate_json(item.json_metadata)
        utils.validate_json(item.position_json)
        owners = list(item.owners)
        for slc in item.slices:
            slc.owners = list(set(owners) | set(slc.owners))

    def pre_update(self, item: "DashboardModelView") -> None:
        check_ownership(item)
        self.pre_add(item)
Example #21
def get_datasource_exist_error_msg(full_name):
    return __("Datasource %(name)s already exists", name=full_name)
Example #22
class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):  # noqa
    datamodel = SQLAInterface(models.SqlaTable)

    list_title = _('List Tables')
    show_title = _('Show Table')
    add_title = _('Import a table definition')
    edit_title = _('Edit Table')

    list_columns = [
        'link', 'database_name',
        'changed_by_', 'modified']
    order_columns = ['modified']
    add_columns = ['database', 'schema', 'table_name']
    edit_columns = [
        'table_name', 'sql', 'filter_select_enabled',
        'fetch_values_predicate', 'database', 'schema',
        'description', 'owners',
        'main_dttm_col', 'default_endpoint', 'offset', 'cache_timeout',
        'is_sqllab_view', 'template_params',
    ]
    base_filters = [['id', DatasourceFilter, lambda: []]]
    show_columns = edit_columns + ['perm', 'slices']
    related_views = [TableColumnInlineView, SqlMetricInlineView]
    base_order = ('changed_on', 'desc')
    search_columns = (
        'database', 'schema', 'table_name', 'owners', 'is_sqllab_view',
    )
    description_columns = {
        'slices': _(
            'The list of charts associated with this table. By '
            'altering this datasource, you may change how these associated '
            'charts behave. '
            'Also note that charts need to point to a datasource, so '
            'this form will fail at saving if removing charts from a '
            'datasource. If you want to change the datasource for a chart, '
            "overwrite the chart from the 'explore view'"),
        'offset': _('Timezone offset (in hours) for this datasource'),
        'table_name': _(
            'Name of the table that exists in the source database'),
        'schema': _(
            'Schema, as used only in some databases like Postgres, Redshift '
            'and DB2'),
        'description': Markup(
            'Supports <a href="https://daringfireball.net/projects/markdown/">'
            'markdown</a>'),
        'sql': _(
            'This field acts as a Superset view, meaning that Superset will '
            'run a query against this string as a subquery.',
        ),
        'fetch_values_predicate': _(
            'Predicate applied when fetching distinct values to '
            'populate the filter control component. Supports '
            'jinja template syntax. Applies only when '
            '`Enable Filter Select` is on.',
        ),
        'default_endpoint': _(
            'Redirects to this endpoint when clicking on the table '
            'from the table list'),
        'filter_select_enabled': _(
            "Whether to populate the filter's dropdown in the explore "
            "view's filter section with a list of distinct values fetched "
            'from the backend on the fly'),
        'is_sqllab_view': _(
            "Whether the table was generated by the 'Visualize' flow "
            'in SQL Lab'),
        'template_params': _(
            'A set of parameters that become available in the query using '
            'Jinja templating syntax'),
        'cache_timeout': _(
            'Duration (in seconds) of the caching timeout for this table. '
            'A timeout of 0 indicates that the cache never expires. '
            'Note this defaults to the database timeout if undefined.'),
    }
    label_columns = {
        'slices': _('Associated Charts'),
        'link': _('Table'),
        'changed_by_': _('Changed By'),
        'database': _('Database'),
        'database_name': _('Database'),
        'changed_on_': _('Last Changed'),
        'filter_select_enabled': _('Enable Filter Select'),
        'schema': _('Schema'),
        'default_endpoint': _('Default Endpoint'),
        'offset': _('Offset'),
        'cache_timeout': _('Cache Timeout'),
        'table_name': _('Table Name'),
        'fetch_values_predicate': _('Fetch Values Predicate'),
        'owners': _('Owners'),
        'main_dttm_col': _('Main Datetime Column'),
        'description': _('Description'),
        'is_sqllab_view': _('SQL Lab View'),
        'template_params': _('Template parameters'),
        'modified': _('Modified'),
    }

    def pre_add(self, table):
        with db.session.no_autoflush:
            table_query = db.session.query(models.SqlaTable).filter(
                models.SqlaTable.table_name == table.table_name,
                models.SqlaTable.schema == table.schema,
                models.SqlaTable.database_id == table.database.id)
            if db.session.query(table_query.exists()).scalar():
                raise Exception(
                    get_datasource_exist_error_msg(table.full_name))

        # Fail before adding if the table can't be found
        try:
            table.get_sqla_table_object()
        except Exception:
            raise Exception(_(
                'Table [{}] could not be found, '
                'please double check your '
                'database connection, schema, and '
                'table name').format(table.name))

    def post_add(self, table, flash_message=True):
        table.fetch_metadata()
        security_manager.merge_perm('datasource_access', table.get_perm())
        if table.schema:
            security_manager.merge_perm('schema_access', table.schema_perm)

        if flash_message:
            flash(_(
                'The table was created. '
                'As part of this two-phase configuration '
                'process, you should now click the edit button by '
                'the new table to configure it.'), 'info')

    def post_update(self, table):
        self.post_add(table, flash_message=False)

    def _delete(self, pk):
        DeleteMixin._delete(self, pk)

    @expose('/edit/<pk>', methods=['GET', 'POST'])
    @has_access
    def edit(self, pk):
        """Simple hack to redirect to explore view after saving"""
        resp = super(TableModelView, self).edit(pk)
        if isinstance(resp, basestring):
            return resp
        return redirect('/superset/explore/table/{}/'.format(pk))

    @action(
        'refresh',
        __('Refresh Metadata'),
        __('Refresh column metadata'),
        'fa-refresh')
    def refresh(self, tables):
        if not isinstance(tables, list):
            tables = [tables]
        successes = []
        failures = []
        for t in tables:
            try:
                t.fetch_metadata()
                successes.append(t)
            except Exception:
                failures.append(t)

        if len(successes) > 0:
            success_msg = _(
                'Metadata refreshed for the following table(s): %(tables)s',
                tables=', '.join([t.table_name for t in successes]))
            flash(success_msg, 'info')
        if len(failures) > 0:
            failure_msg = _(
                'Unable to retrieve metadata for the following table(s): %(tables)s',
                tables=', '.join([t.table_name for t in failures]))
            flash(failure_msg, 'danger')

        return redirect('/tablemodelview/list/')
Example #23
    def run(self) -> None:
        engine = self._properties["engine"]
        engine_specs = get_engine_specs()

        if engine in BYPASS_VALIDATION_ENGINES:
            # Skip engines that are only validated onCreate
            return

        if engine not in engine_specs:
            raise InvalidEngineError(
                SupersetError(
                    message=__(
                        'Engine "%(engine)s" is not a valid engine.',
                        engine=engine,
                    ),
                    error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                    level=ErrorLevel.ERROR,
                    extra={
                        "allowed": list(engine_specs),
                        "provided": engine
                    },
                ), )
        engine_spec = engine_specs[engine]
        if not hasattr(engine_spec, "parameters_schema"):
            raise InvalidEngineError(
                SupersetError(
                    message=__(
                        'Engine "%(engine)s" cannot be configured through parameters.',
                        engine=engine,
                    ),
                    error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                    level=ErrorLevel.ERROR,
                    extra={
                        "allowed": [
                            name for name, engine_spec in engine_specs.items()
                            if issubclass(engine_spec, BasicParametersMixin)
                        ],
                        "provided":
                        engine,
                    },
                ), )

        # perform initial validation
        errors = engine_spec.validate_parameters(  # type: ignore
            self._properties.get("parameters", {}))
        if errors:
            event_logger.log_with_context(action="validation_error",
                                          engine=engine)
            raise InvalidParametersError(errors)

        serialized_encrypted_extra = self._properties.get(
            "encrypted_extra", "{}")
        try:
            encrypted_extra = json.loads(serialized_encrypted_extra)
        except json.decoder.JSONDecodeError:
            encrypted_extra = {}

        # try to connect
        sqlalchemy_uri = engine_spec.build_sqlalchemy_uri(  # type: ignore
            self._properties.get("parameters"),
            encrypted_extra,
        )
        if self._model and sqlalchemy_uri == self._model.safe_sqlalchemy_uri():
            sqlalchemy_uri = self._model.sqlalchemy_uri_decrypted
        database = DatabaseDAO.build_db_for_connection_test(
            server_cert=self._properties.get("server_cert", ""),
            extra=self._properties.get("extra", "{}"),
            impersonate_user=self._properties.get("impersonate_user", False),
            encrypted_extra=serialized_encrypted_extra,
        )
        database.set_sqlalchemy_uri(sqlalchemy_uri)
        database.db_engine_spec.mutate_db_for_connection_test(database)
        username = self._actor.username if self._actor is not None else None
        engine = database.get_sqla_engine(user_name=username)
        try:
            with closing(engine.raw_connection()) as conn:
                alive = engine.dialect.do_ping(conn)
        except Exception as ex:
            url = make_url_safe(sqlalchemy_uri)
            context = {
                "hostname": url.host,
                "password": url.password,
                "port": url.port,
                "username": url.username,
                "database": url.database,
            }
            errors = database.db_engine_spec.extract_errors(ex, context)
            raise DatabaseTestConnectionFailedError(errors) from ex

        if not alive:
            raise DatabaseOfflineError(
                SupersetError(
                    message=__("Database is offline."),
                    error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                    level=ErrorLevel.ERROR,
                ), )
Example #24
class QueryView(RookModelView):
    datamodel = SQLAInterface(Query)
    list_columns = ['user', 'database', 'status', 'start_time', 'end_time']
    label_columns = {
        'user': _('User'),
        'database': _('Database'),
        'status': _('Status'),
        'start_time': _('Start Time'),
        'end_time': _('End Time'),
    }


appbuilder.add_view(QueryView,
                    'Queries',
                    label=__('Queries'),
                    category='Manage',
                    category_label=__('Manage'),
                    icon='fa-search')


class SavedQueryView(RookModelView, DeleteMixin):
    datamodel = SQLAInterface(SavedQuery)

    list_title = _('List Saved Query')
    show_title = _('Show Saved Query')
    add_title = _('Add Saved Query')
    edit_title = _('Edit Saved Query')

    list_columns = [
        'label', 'user', 'database', 'schema', 'description', 'modified',
Example #25
    def run(self) -> Optional[Dict[str, Any]]:
        initial_form_data = {}

        if self._permalink_key is not None:
            command = GetExplorePermalinkCommand(self._permalink_key)
            permalink_value = command.run()
            if not permalink_value:
                raise ExplorePermalinkGetFailedError()
            state = permalink_value["state"]
            initial_form_data = state["formData"]
            url_params = state.get("urlParams")
            if url_params:
                initial_form_data["url_params"] = dict(url_params)
        elif self._form_data_key:
            parameters = FormDataCommandParameters(key=self._form_data_key)
            value = GetFormDataCommand(parameters).run()
            initial_form_data = json.loads(value) if value else {}

        message = None

        if not initial_form_data:
            if self._slice_id:
                initial_form_data["slice_id"] = self._slice_id
                if self._form_data_key:
                    message = _(
                        "Form data not found in cache, reverting to chart metadata."
                    )
            elif self._dataset_id:
                initial_form_data[
                    "datasource"
                ] = f"{self._dataset_id}__{self._dataset_type}"
                if self._form_data_key:
                    message = _(
                        "Form data not found in cache, reverting to dataset metadata."
                    )

        form_data, slc = get_form_data(
            use_slice_data=True, initial_form_data=initial_form_data
        )
        try:
            self._dataset_id, self._dataset_type = get_datasource_info(
                self._dataset_id, self._dataset_type, form_data
            )
        except SupersetException:
            self._dataset_id = None
            # fall back to the table type for unknown datasources
            self._dataset_type = SqlaTable.type

        dataset: Optional[BaseDatasource] = None
        if self._dataset_id is not None:
            try:
                dataset = DatasourceDAO.get_datasource(
                    db.session, cast(str, self._dataset_type), self._dataset_id
                )
            except DatasetNotFoundError:
                pass
        dataset_name = dataset.name if dataset else _("[Missing Dataset]")

        if dataset:
            if app.config["ENABLE_ACCESS_REQUEST"] and (
                not security_manager.can_access_datasource(dataset)
            ):
                message = __(security_manager.get_datasource_access_error_msg(dataset))
                raise DatasetAccessDeniedError(
                    message=message,
                    dataset_type=self._dataset_type,
                    dataset_id=self._dataset_id,
                )

        viz_type = form_data.get("viz_type")
        if not viz_type and dataset and dataset.default_endpoint:
            raise WrongEndpointError(redirect=dataset.default_endpoint)

        form_data["datasource"] = (
            str(self._dataset_id) + "__" + cast(str, self._dataset_type)
        )

        # On explore, merge legacy and extra filters into the form data
        utils.convert_legacy_filters_into_adhoc(form_data)
        utils.merge_extra_filters(form_data)

        dummy_dataset_data: Dict[str, Any] = {
            "type": self._dataset_type,
            "name": dataset_name,
            "columns": [],
            "metrics": [],
            "database": {"id": 0, "backend": ""},
        }
        try:
            dataset_data = dataset.data if dataset else dummy_dataset_data
        except (SupersetException, SQLAlchemyError):
            dataset_data = dummy_dataset_data

        return {
            "dataset": sanitize_datasource_data(dataset_data),
            "form_data": form_data,
            "slice": slc.data if slc else None,
            "message": message,
        }
Example #26
class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):
    datamodel = SQLAInterface(models.SqlaTable)
    include_route_methods = RouteMethod.CRUD_SET

    list_title = _("Tables")
    show_title = _("Show Table")
    add_title = _("Import a table definition")
    edit_title = _("Edit Table")

    list_columns = ["link", "database_name", "changed_by_", "modified"]
    order_columns = ["modified"]
    add_columns = ["database", "schema", "table_name"]
    edit_columns = [
        "table_name",
        "sql",
        "filter_select_enabled",
        "fetch_values_predicate",
        "database",
        "schema",
        "description",
        "owners",
        "main_dttm_col",
        "default_endpoint",
        "offset",
        "cache_timeout",
        "is_sqllab_view",
        "template_params",
    ]
    base_filters = [["id", DatasourceFilter, lambda: []]]
    show_columns = edit_columns + ["perm", "slices"]
    related_views = [
        TableColumnInlineView,
        SqlMetricInlineView,
        RowLevelSecurityFiltersModelView,
    ]
    base_order = ("changed_on", "desc")
    search_columns = ("database", "schema", "table_name", "owners",
                      "is_sqllab_view")
    description_columns = {
        "slices":
        _("The list of charts associated with this table. By "
          "altering this datasource, you may change how these associated "
          "charts behave. "
          "Also note that charts need to point to a datasource, so "
          "this form will fail at saving if removing charts from a "
          "datasource. If you want to change the datasource for a chart, "
          "overwrite the chart from the 'explore view'"),
        "offset":
        _("Timezone offset (in hours) for this datasource"),
        "table_name":
        _("Name of the table that exists in the source database"),
        "schema":
        _("Schema, as used only in some databases like Postgres, Redshift "
          "and DB2"),
        "description":
        Markup(
            'Supports <a href="https://daringfireball.net/projects/markdown/">'
            "markdown</a>"),
        "sql":
        _("This field acts as a Superset view, meaning that Superset will "
          "run a query against this string as a subquery."),
        "fetch_values_predicate":
        _("Predicate applied when fetching distinct values to "
          "populate the filter control component. Supports "
          "jinja template syntax. Applies only when "
          "`Enable Filter Select` is on."),
        "default_endpoint":
        _("Redirects to this endpoint when clicking on the table "
          "from the table list"),
        "filter_select_enabled":
        _("Whether to populate the filter's dropdown in the explore "
          "view's filter section with a list of distinct values fetched "
          "from the backend on the fly"),
        "is_sqllab_view":
        _("Whether the table was generated by the 'Visualize' flow "
          "in SQL Lab"),
        "template_params":
        _("A set of parameters that become available in the query using "
          "Jinja templating syntax"),
        "cache_timeout":
        _("Duration (in seconds) of the caching timeout for this table. "
          "A timeout of 0 indicates that the cache never expires. "
          "Note this defaults to the database timeout if undefined."),
    }
    label_columns = {
        "slices": _("Associated Charts"),
        "link": _("Table"),
        "changed_by_": _("Changed By"),
        "database": _("Database"),
        "database_name": _("Database"),
        "changed_on_": _("Last Changed"),
        "filter_select_enabled": _("Enable Filter Select"),
        "schema": _("Schema"),
        "default_endpoint": _("Default Endpoint"),
        "offset": _("Offset"),
        "cache_timeout": _("Cache Timeout"),
        "table_name": _("Table Name"),
        "fetch_values_predicate": _("Fetch Values Predicate"),
        "owners": _("Owners"),
        "main_dttm_col": _("Main Datetime Column"),
        "description": _("Description"),
        "is_sqllab_view": _("SQL Lab View"),
        "template_params": _("Template parameters"),
        "modified": _("Modified"),
    }
    edit_form_extra_fields = {
        "database":
        QuerySelectField(
            "Database",
            query_factory=lambda: db.session().query(models.Database),
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    def pre_add(self, table: "TableModelView") -> None:
        validate_sqlatable(table)

    def post_add(self,
                 table: "TableModelView",
                 flash_message: bool = True) -> None:
        table.fetch_metadata()
        create_table_permissions(table)
        if flash_message:
            flash(
                _("The table was created. "
                  "As part of this two-phase configuration "
                  "process, you should now click the edit button by "
                  "the new table to configure it."),
                "info",
            )

    def post_update(self, table: "TableModelView") -> None:
        self.post_add(table, flash_message=False)

    def _delete(self, pk: int) -> None:
        DeleteMixin._delete(self, pk)

    @expose("/edit/<pk>", methods=["GET", "POST"])
    @has_access
    def edit(self, pk: int) -> FlaskResponse:
        """Simple hack to redirect to explore view after saving"""
        resp = super(TableModelView, self).edit(pk)
        if isinstance(resp, str):
            return resp
        return redirect("/superset/explore/table/{}/".format(pk))

    @action("refresh", __("Refresh Metadata"), __("Refresh column metadata"),
            "fa-refresh")
    def refresh(
        self, tables: Union["TableModelView",
                            List["TableModelView"]]) -> FlaskResponse:
        if not isinstance(tables, list):
            tables = [tables]
        successes = []
        failures = []
        for t in tables:
            try:
                t.fetch_metadata()
                successes.append(t)
            except Exception:
                failures.append(t)

        if len(successes) > 0:
            success_msg = _(
                "Metadata refreshed for the following table(s): %(tables)s",
                tables=", ".join([t.table_name for t in successes]),
            )
            flash(success_msg, "info")
        if len(failures) > 0:
            failure_msg = _(
                "Unable to retrieve metadata for the following table(s): %(tables)s",
                tables=", ".join([t.table_name for t in failures]),
            )
            flash(failure_msg, "danger")

        return redirect("/tablemodelview/list/")

    @expose("/list/")
    @has_access
    def list(self) -> FlaskResponse:
        if not app.config["ENABLE_REACT_CRUD_VIEWS"]:
            return super().list()

        return super().render_app_template()
Example #27
logger = logging.getLogger()


class TimeGrain(NamedTuple):  # pylint: disable=too-few-public-methods
    name: str  # TODO: redundant field, remove
    label: str
    function: str
    duration: Optional[str]


QueryStatus = utils.QueryStatus
config = app.config

builtin_time_grains: Dict[Optional[str], str] = {
    None: __("Original value"),
    "PT1S": __("Second"),
    "PT1M": __("Minute"),
    "PT5M": __("5 minute"),
    "PT10M": __("10 minute"),
    "PT15M": __("15 minute"),
    "PT0.5H": __("Half hour"),
    "PT1H": __("Hour"),
    "P1D": __("Day"),
    "P1W": __("Week"),
    "P1M": __("Month"),
    "P0.25Y": __("Quarter"),
    "P1Y": __("Year"),
    "1969-12-28T00:00:00Z/P1W": __("Week starting sunday"),
    "1969-12-29T00:00:00Z/P1W": __("Week starting monday"),
    "P1W/1970-01-03T00:00:00Z": __("Week ending saturday"),
Example #28
        'broker_endpoint': _("Broker Endpoint"),
    }

    def pre_add(self, cluster):
        security.merge_perm(sm, 'database_access', cluster.perm)

    def pre_update(self, cluster):
        self.pre_add(cluster)

    def _delete(self, pk):
        DeleteMixin._delete(self, pk)

appbuilder.add_view(
    DruidClusterModelView,
    name="Druid Clusters",
    label=__("Druid Clusters"),
    icon="fa-cubes",
    category="Sources",
    category_label=__("Sources"),
    category_icon='fa-database',)


class DruidDatasourceModelView(DatasourceModelView, DeleteMixin):  # noqa
    datamodel = SQLAInterface(models.DruidDatasource)

    list_title = _('List Druid Datasource')
    show_title = _('Show Druid Datasource')
    add_title = _('Add Druid Datasource')
    edit_title = _('Edit Druid Datasource')

    list_widget = ListWidgetWithCheckboxes
Example #29
            obj.end_dttm = obj.start_dttm
        elif obj.end_dttm < obj.start_dttm:
            raise Exception(
                'Annotation end time must be no earlier than start time.')

    def pre_update(self, obj):
        self.pre_add(obj)


class AnnotationLayerModelView(SupersetModelView, DeleteMixin):
    datamodel = SQLAInterface(AnnotationLayer)
    list_columns = ['id', 'name']
    edit_columns = ['name', 'descr']
    add_columns = edit_columns


appbuilder.add_view(AnnotationLayerModelView,
                    'Annotation Layers',
                    label=__('Annotation Layers'),
                    icon='fa-comment',
                    category='Manage',
                    category_label=__('Manage'),
                    category_icon='')
appbuilder.add_view(AnnotationModelView,
                    'Annotations',
                    label=__('Annotations'),
                    icon='fa-comments',
                    category='Manage',
                    category_label=__('Manage'),
                    category_icon='')
Example #30
class Job_Template_ModelView_Base():
    datamodel = SQLAInterface(Job_Template)
    label_title = 'Job Template'
    check_redirect_list_url = '/job_template_modelview/list/?_flt_2_name='
    help_url = conf.get('HELP_URL', {}).get(datamodel.obj.__tablename__,
                                            '') if datamodel else ''
    list_columns = ['project', 'name_title', 'version', 'creator', 'modified']
    show_columns = [
        'project', 'name', 'version', 'describe', 'images_url', 'workdir',
        'entrypoint', 'args_html', 'demo_html', 'env', 'hostAliases',
        'privileged', 'expand_html'
    ]
    add_columns = [
        'project', 'images', 'name', 'version', 'describe', 'workdir',
        'entrypoint', 'volume_mount', 'job_args_definition', 'args', 'env',
        'hostAliases', 'privileged', 'accounts', 'demo', 'expand'
    ]
    edit_columns = add_columns

    base_filters = [["id", Job_Tempalte_Filter, lambda: []]]  # permission filter
    base_order = ('id', 'desc')
    order_columns = ['id']
    add_form_query_rel_fields = {
        "images": [["name", Images_Filter, None]],
        "project": [["name", Project_Filter, 'job-template']],
    }
    version_list = [[version, version] for version in ['Alpha', 'Release']]
    edit_form_query_rel_fields = add_form_query_rel_fields
    add_form_extra_fields = {
        "name":
        StringField(
            _(datamodel.obj.lab('name')),
            description='English name (letters, digits and "-"), at most 50 characters',
            widget=BS3TextFieldWidget(
            ),  # the widget function receives the outer field object and the widget's arguments
            validators=[Regexp("^[a-z][a-z0-9\-]*[a-z0-9]$"),
                        Length(1, 54)]),
        "version":
        SelectField(_(datamodel.obj.lab('version')),
                    description="version of the job template; only Release templates are visible to all users",
                    widget=Select2Widget(),
                    choices=version_list),
        "volume_mount":
        StringField(
            _(datamodel.obj.lab('volume_mount')),
            description='tasks that use this template automatically get this mount added when saved',
            widget=BS3TextFieldWidget(
            ),  # the widget function receives the outer field object and the widget's arguments
        ),
        "workdir":
        StringField(
            _(datamodel.obj.lab('workdir')),
            description="working directory; if empty, the image's default working directory is used",
            widget=BS3TextFieldWidget(
            ),  # the widget function receives the outer field object and the widget's arguments
        ),
        "entrypoint":
        StringField(
            _(datamodel.obj.lab('entrypoint')),
            description='entrypoint command of the image, written as a single-line string, e.g. python xx.py, without []',
            widget=BS3TextFieldWidget(
            ),  # the widget function receives the outer field object and the widget's arguments
        ),
        "args":
        StringField(
            _(datamodel.obj.lab('args')),
            description='parameters the task must fill in when using this job template, following the Job Args Definition standard',
            widget=MyBS3TextAreaFieldWidget(
                rows=10),  # the widget function receives the outer field object and the widget's arguments
        ),
        "env":
        StringField(
            _(datamodel.obj.lab('env')),
            description=
            'environment variables automatically added to tasks using this template; template variables are supported. Format: one env_key=env_value per line',
            widget=MyBS3TextAreaFieldWidget(
                rows=3),  # the widget function receives the outer field object and the widget's arguments
        ),
        "hostAliases":
        StringField(
            _(datamodel.obj.lab('hostAliases')),
            description=
            'host mappings added inside the container. Format: one DNS record per line, "ip host1 host2", e.g. 1.1.1.1 example1.oa.com example2.oa.com',
            widget=MyBS3TextAreaFieldWidget(
                rows=3),  # the widget function receives the outer field object and the widget's arguments
        ),
        "demo":
        StringField(
            _(datamodel.obj.lab('demo')),
            description='fill in a demo',
            widget=MyBS3TextAreaFieldWidget(
                rows=10),  # the widget function receives the outer field object and the widget's arguments
        ),
        "job_args_definition":
        StringField(
            _(datamodel.obj.lab('job_args_definition')),
            description='the standard way to fill in job template parameters',
            widget=MyCodeArea(code=core.job_template_args_definition()
                              ),  # the widget function receives the outer field object and the widget's arguments
        ),
        "privileged":
        BooleanField(_(datamodel.obj.lab('privileged')),
                     description='whether to run with privileged mode'),
        "expand":
        StringField(
            _(datamodel.obj.lab('expand')),
            description='JSON-format extension field; supports index: $display order of the template, help_url: $help document URL',
            widget=MyBS3TextAreaFieldWidget(
                rows=3),  # the widget function receives the outer field object and the widget's arguments
        ),
    }
    edit_form_extra_fields = add_form_extra_fields

    # validate that the fields are valid JSON
    # @pysnooper.snoop(watch_explode=('job_args'))
    def pre_add(self, item):
        if not item.env:
            item.env = ''
        envs = item.env.strip().split('\n')
        envs = [env.strip() for env in envs if env.strip() and '=' in env]
        item.env = '\n'.join(envs)
        if not item.args:
            item.args = '{}'
        item.args = core.validate_job_args(item)

        if not item.expand or not item.expand.strip():
            item.expand = '{}'
        core.validate_json(item.expand)
        item.expand = json.dumps(json.loads(item.expand),
                                 indent=4,
                                 ensure_ascii=False)

        if not item.demo or not item.demo.strip():
            item.demo = '{}'

        core.validate_json(item.demo)

        if item.hostAliases:
            # if not item.images.entrypoint:
            #     raise MyappException('images entrypoint not exist')
            all_host = {}
            all_rows = re.split('\r|\n', item.hostAliases)
            all_rows = [
                all_row.strip() for all_row in all_rows if all_row.strip()
            ]
            for row in all_rows:
                hosts = row.split(' ')
                hosts = [host for host in hosts if host]
                if len(hosts) > 1:
                    if hosts[0] in all_host:
                        all_host[hosts[0]] = all_host[hosts[0]] + hosts[1:]
                    else:
                        all_host[hosts[0]] = hosts[1:]

            hostAliases = ''
            for ip in all_host:
                hostAliases += ip + " " + " ".join(all_host[ip])
                hostAliases += '\n'
            item.hostAliases = hostAliases.strip()

        task_args = json.loads(item.demo)
        job_args = json.loads(item.args)
        item.demo = json.dumps(core.validate_task_args(task_args, job_args),
                               indent=4,
                               ensure_ascii=False)

    # check edit permission: only the creator and admin can edit
    def check_edit_permission(self, item):
        user_roles = [role.name.lower() for role in list(get_user_roles())]
        if "admin" in user_roles:
            return True
        if g.user and g.user.username and hasattr(item, 'created_by'):
            if g.user.username == item.created_by.username:
                return True
        flash('Only the creator or an admin can edit/delete.', 'warning')
        return False

    def pre_update(self, item):
        self.pre_add(item)

    # @pysnooper.snoop()
    def post_list(self, items):
        def sort_expand_index(items, dbsession):
            all = {0: []}
            for item in items:
                try:
                    if item.expand:
                        index = float(json.loads(item.expand).get(
                            'index', 0)) + float(
                                json.loads(item.project.expand).get(
                                    'index', 0)) * 1000
                        if index:
                            if index in all:
                                all[index].append(item)
                            else:
                                all[index] = [item]
                        else:
                            all[0].append(item)
                    else:
                        all[0].append(item)
                except Exception as e:
                    print(e)
            back = []
            for index in sorted(all):
                back.extend(all[index])
                # automatically normalize when the index has a fractional part
                # if float(index)!=int(index):
                #     pass
            return back

        return sort_expand_index(items, db.session)

    @expose("/run", methods=["POST"])
    def run(self):
        request_data = request.json
        job_template_id = request_data.get('job_template_id', '')
        job_template_name = request_data.get('job_template_name', '')
        run_id = request_data.get('run_id', '').replace('_', '-')
        resource_memory = request_data.get('resource_memory', '')
        resource_cpu = request_data.get('resource_cpu', '')
        task_args = request_data.get('args', '')
        if (not job_template_id
                and not job_template_name) or not run_id or task_args == '':
            response = make_response("required parameters are missing")
            response.status_code = 400
            return response

        job_template = None
        if job_template_id:
            job_template = db.session.query(Job_Template).filter_by(
                id=int(job_template_id)).first()
        elif job_template_name:
            job_template = db.session.query(Job_Template).filter_by(
                name=job_template_name).first()
        if not job_template:
            response = make_response("no job template exist")
            response.status_code = 400
            return response

        from myapp.utils.py.py_k8s import K8s

        k8s = K8s()
        namespace = conf.get('PIPELINE_NAMESPACE')
        pod_name = "venus-" + run_id.replace('_', '-')
        pod_name = pod_name.lower()[:60].strip('-')
        pod = k8s.get_pods(namespace=namespace, pod_name=pod_name)
        # print(pod)
        if pod:
            pod = pod[0]
        # a previous pod exists: delete it first
        if pod:
            k8s.delete_pods(namespace=namespace, pod_name=pod_name)
            time.sleep(2)
            pod = None
        # no previous pod (or it is not running): create one
        if not pod:
            args = []

            job_template_args = json.loads(
                job_template.args) if job_template.args else {}
            for arg_name in task_args:
                arg_type = ''
                for group in job_template_args:
                    for template_arg in job_template_args[group]:
                        if template_arg == arg_name:
                            arg_type = job_template_args[group][
                                template_arg].get('type', '')
                arg_value = task_args[arg_name]
                if arg_value:
                    args.append(arg_name)
                    if arg_type == 'json':
                        args.append(json.dumps(arg_value))
                    else:
                        args.append(arg_value)

            # command = ['sh', '-c','sleep 7200']
            volume_mount = 'kubeflow-cfs-workspace(pvc):/mnt,kubeflow-cfs-archives(pvc):/archives'
            env = job_template.env + "\n"
            env += 'KFJ_TASK_ID=0\n'
            env += 'KFJ_TASK_NAME=' + str('venus-' + run_id) + "\n"
            env += 'KFJ_TASK_NODE_SELECTOR=cpu=true,train=true\n'
            env += 'KFJ_TASK_VOLUME_MOUNT=' + str(volume_mount) + "\n"
            env += 'KFJ_TASK_IMAGES=' + str(job_template.images) + "\n"
            env += 'KFJ_TASK_RESOURCE_CPU=' + str(resource_cpu) + "\n"
            env += 'KFJ_TASK_RESOURCE_MEMORY=' + str(resource_memory) + "\n"
            env += 'KFJ_TASK_RESOURCE_GPU=0\n'
            env += 'KFJ_PIPELINE_ID=0\n'
            env += 'KFJ_RUN_ID=' + run_id + "\n"
            env += 'KFJ_CREATOR=' + str(g.user.username) + "\n"
            env += 'KFJ_RUNNER=' + str(g.user.username) + "\n"
            env += 'KFJ_PIPELINE_NAME=venus\n'
            env += 'KFJ_NAMESPACE=pipeline' + "\n"

            def template_str(src_str):
                rtemplate = Environment(
                    loader=BaseLoader,
                    undefined=DebugUndefined).from_string(src_str)
                des_str = rtemplate.render(
                    creator=g.user.username,
                    datetime=datetime,
                    runner=g.user.username,
                    uuid=uuid,
                    pipeline_id='0',
                    pipeline_name='venus-task',
                    cluster_name=conf.get('ENVIRONMENT'))
                return des_str

            global_envs = json.loads(
                template_str(
                    json.dumps(conf.get('GLOBAL_ENV', {}),
                               indent=4,
                               ensure_ascii=False)))
            for global_env_key in global_envs:
                env += global_env_key + '=' + global_envs[global_env_key] + "\n"

            hostAliases = job_template.hostAliases + "\n" + conf.get(
                'HOSTALIASES', '')
            k8s.create_debug_pod(
                namespace,
                name=pod_name,
                labels={'run-rtx': g.user.username},
                command=None,
                args=args,
                volume_mount=volume_mount,
                working_dir=None,
                node_selector='cpu=true,train=true',
                resource_cpu=resource_cpu,
                resource_memory=resource_memory,
                resource_gpu=0,
                image_pull_policy=conf.get('IMAGE_PULL_POLICY', 'Always'),
                image_pull_secrets=[job_template.images.repository.hubsecret],
                image=job_template.images.name,
                hostAliases=hostAliases,
                env=env,
                privileged=job_template.privileged,
                accounts=job_template.accounts,
                username=g.user.username)

        try_num = 5
        while (try_num > 0):
            pod = k8s.get_pods(namespace=namespace, pod_name=pod_name)
            # print(pod)
            if pod:
                break
            try_num = try_num - 1
            time.sleep(2)
        if try_num == 0:
            response = make_response("启动时间过长,一分钟后重试")
            response.status_code = 400
            return response

        user_roles = [role.name.lower() for role in list(g.user.roles)]
        if "admin" in user_roles:
            pod_url = conf.get(
                'K8S_DASHBOARD_CLUSTER'
            ) + "#/log/%s/%s/pod?namespace=%s&container=%s" % (
                namespace, pod_name, namespace, pod_name)
        else:
            pod_url = conf.get(
                'K8S_DASHBOARD_PIPELINE'
            ) + "#/log/%s/%s/pod?namespace=%s&container=%s" % (
                namespace, pod_name, namespace, pod_name)
        print(pod_url)
        response = make_response("启动成功,日志地址: %s" % pod_url)
        response.status_code = 200
        return response

    @expose("/listen", methods=["POST"])
    def listen(self):
        request_data = request.json
        run_id = request_data.get('run_id', '').replace('_', '-')
        if not run_id:
            response = make_response("输入参数不齐全")
            response.status_code = 400
            return response

        from myapp.utils.py.py_k8s import K8s
        k8s = K8s()
        namespace = conf.get('PIPELINE_NAMESPACE')
        pod_name = "venus-" + run_id.replace('_', '-')
        pod_name = pod_name.lower()[:60].strip('-')
        pod = k8s.get_pods(namespace=namespace, pod_name=pod_name)
        # print(pod)
        if pod:
            pod = pod[0]
            if isinstance(pod['start_time'], datetime.datetime):
                pod['start_time'] = pod['start_time'].strftime(
                    "%Y-%m-%d %H:%M:%S")
            print(pod)
            response = make_response(json.dumps(pod))
            response.status_code = 200
            return response
        else:
            response = make_response('no pod')
            response.status_code = 400
            return response

    @action("copy",
            __("Copy Job Template"),
            confirmation=__('Copy Job Template'),
            icon="fa-copy",
            multiple=True,
            single=False)
    def copy(self, job_templates):
        if not isinstance(job_templates, list):
            job_templates = [job_templates]
        try:
            for job_template in job_templates:
                new_job_template = job_template.clone()
                new_job_template.name = (
                    new_job_template.name + "_copy_" + uuid.uuid4().hex[:4])
                new_job_template.created_on = datetime.datetime.now()
                new_job_template.changed_on = datetime.datetime.now()
                db.session.add(new_job_template)
                db.session.commit()
        except InvalidRequestError:
            db.session.rollback()
        except Exception as e:
            raise e
        return redirect(request.referrer)
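
A minimal, self-contained sketch of the `template_str` pattern used above: configuration values are serialized to JSON, rendered through Jinja2 (so placeholders such as `creator` are filled in), then parsed back and exported as environment variables. The `GLOBAL_ENV` value and context names below are illustrative stand-ins, not the real configuration.

# Sketch: render config values through Jinja2 before exporting them as env vars.
import datetime
import json
import uuid

from jinja2 import BaseLoader, DebugUndefined, Environment

# Made-up example value; the real one would come from conf.get('GLOBAL_ENV', {}).
GLOBAL_ENV = {
    "ARCHIVE_DIR": "/archives/{{ creator }}",
    "RUN_YEAR": "{{ datetime.datetime.now().year }}",
}

def render(src: str, creator: str) -> str:
    # DebugUndefined leaves unknown placeholders untouched instead of failing.
    tpl = Environment(loader=BaseLoader(), undefined=DebugUndefined).from_string(src)
    return tpl.render(creator=creator, datetime=datetime, uuid=uuid)

rendered = json.loads(render(json.dumps(GLOBAL_ENV), creator="alice"))
env_block = "".join(f"{key}={value}\n" for key, value in rendered.items())
print(env_block)  # ARCHIVE_DIR=/archives/alice, RUN_YEAR=<current year>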
Exemple #31
0
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
    engine = "snowflake"
    engine_name = "Snowflake"
    force_column_alias_quotes = True
    max_column_name_length = 256

    _time_grain_expressions = {
        None: "{col}",
        "PT1S": "DATE_TRUNC('SECOND', {col})",
        "PT1M": "DATE_TRUNC('MINUTE', {col})",
        "PT5M": "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 5) * 5, \
                DATE_TRUNC('HOUR', {col}))",
        "PT10M": "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 10) * 10, \
                 DATE_TRUNC('HOUR', {col}))",
        "PT15M": "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 15) * 15, \
                 DATE_TRUNC('HOUR', {col}))",
        "PT0.5H": "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 30) * 30, \
                  DATE_TRUNC('HOUR', {col}))",
        "PT1H": "DATE_TRUNC('HOUR', {col})",
        "P1D": "DATE_TRUNC('DAY', {col})",
        "P1W": "DATE_TRUNC('WEEK', {col})",
        "P1M": "DATE_TRUNC('MONTH', {col})",
        "P0.25Y": "DATE_TRUNC('QUARTER', {col})",
        "P1Y": "DATE_TRUNC('YEAR', {col})",
    }

    custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[
        str, Any]]] = {
            OBJECT_DOES_NOT_EXIST_REGEX: (
                __("%(object)s does not exist in this database."),
                SupersetErrorType.OBJECT_DOES_NOT_EXIST_ERROR,
                {},
            ),
            SYNTAX_ERROR_REGEX: (
                __("Please check your query for syntax errors at or "
                   'near "%(syntax_error)s". Then, try running your query again.'
                   ),
                SupersetErrorType.SYNTAX_ERROR,
                {},
            ),
        }

    @classmethod
    def adjust_database_uri(cls,
                            uri: URL,
                            selected_schema: Optional[str] = None) -> None:
        database = uri.database
        if "/" in uri.database:
            database = uri.database.split("/")[0]
        if selected_schema:
            selected_schema = parse.quote(selected_schema, safe="")
            uri.database = database + "/" + selected_schema

    @classmethod
    def epoch_to_dttm(cls) -> str:
        return "DATEADD(S, {col}, '1970-01-01')"

    @classmethod
    def epoch_ms_to_dttm(cls) -> str:
        return "DATEADD(MS, {col}, '1970-01-01')"

    @classmethod
    def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
        tt = target_type.upper()
        if tt == utils.TemporalType.DATE:
            return f"TO_DATE('{dttm.date().isoformat()}')"
        if tt == utils.TemporalType.DATETIME:
            return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)"""
        if tt == utils.TemporalType.TIMESTAMP:
            return f"""TO_TIMESTAMP('{dttm.isoformat(timespec="microseconds")}')"""
        return None

    @staticmethod
    def mutate_db_for_connection_test(database: "Database") -> None:
        """
        By default, Snowflake does not validate whether the user/role has access
        to the chosen database.

        :param database: instance to be mutated
        """
        extra = json.loads(database.extra or "{}")
        engine_params = extra.get("engine_params", {})
        connect_args = engine_params.get("connect_args", {})
        connect_args["validate_default_parameters"] = True
        engine_params["connect_args"] = connect_args
        extra["engine_params"] = engine_params
        database.extra = json.dumps(extra)

    @classmethod
    def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
        """
        Get Snowflake session ID that will be used to cancel all other running
        queries in the same session.

        :param cursor: Cursor instance in which the query will be executed
        :param query: Query instance
        :return: Snowflake Session ID
        """
        cursor.execute("SELECT CURRENT_SESSION()")
        row = cursor.fetchone()
        return row[0]

    @classmethod
    def cancel_query(cls, cursor: Any, query: Query,
                     cancel_query_id: str) -> bool:
        """
        Cancel query in the underlying database.

        :param cursor: New cursor instance to the db of the query
        :param query: Query instance
        :param cancel_query_id: Snowflake Session ID
        :return: True if query cancelled successfully, False otherwise
        """
        try:
            cursor.execute(
                f"SELECT SYSTEM$CANCEL_ALL_QUERIES({cancel_query_id})")
        except Exception:  # pylint: disable=broad-except
            return False

        return True
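
A small usage sketch of the two cancellation hooks defined above, assuming `SnowflakeEngineSpec` is in scope. The cursor is a stand-in object so the snippet runs without a Snowflake connection; in practice both calls receive `snowflake.connector` cursors, and neither hook actually reads the `query` argument.

# Sketch: capture the session id at query start, then cancel from another cursor.
class FakeCursor:
    def execute(self, sql):
        print("EXECUTE:", sql)

    def fetchone(self):
        return ("12345",)  # pretend this is the current session id

session_id = SnowflakeEngineSpec.get_cancel_query_id(FakeCursor(), query=None)
ok = SnowflakeEngineSpec.cancel_query(FakeCursor(), query=None, cancel_query_id=session_id)
print(session_id, ok)  # 12345 True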
Exemple #32
0
def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-locals, too-many-statements, too-many-branches
    query_id: int,
    rendered_query: str,
    return_results: bool,
    store_results: bool,
    user_name: Optional[str],
    session: Session,
    start_time: Optional[float],
    expand_data: bool,
    log_params: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Executes the sql query returns the results."""
    if store_results and start_time:
        # only asynchronous queries
        stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)

    query = get_query(query_id, session)
    payload: Dict[str, Any] = dict(query_id=query_id)
    database = query.database
    db_engine_spec = database.db_engine_spec
    db_engine_spec.patch()

    if database.allow_run_async and not results_backend:
        raise SupersetErrorException(
            SupersetError(
                message=__("Results backend is not configured."),
                error_type=SupersetErrorType.RESULTS_BACKEND_NOT_CONFIGURED_ERROR,
                level=ErrorLevel.ERROR,
            )
        )

    # Breaking down into multiple statements
    parsed_query = ParsedQuery(rendered_query, strip_comments=True)
    if not db_engine_spec.run_multiple_statements_as_one:
        statements = parsed_query.get_statements()
        logger.info(
            "Query %s: Executing %i statement(s)", str(query_id), len(statements)
        )
    else:
        statements = [rendered_query]
        logger.info("Query %s: Executing query as a single statement", str(query_id))

    logger.info("Query %s: Set query to 'running'", str(query_id))
    query.status = QueryStatus.RUNNING
    query.start_running_time = now_as_float()
    session.commit()

    # Should we create a table or view from the select?
    if (
        query.select_as_cta
        and query.ctas_method == CtasMethod.TABLE
        and not parsed_query.is_valid_ctas()
    ):
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    "CTAS (create table as select) can only be run with a query where "
                    "the last statement is a SELECT. Please make sure your query has "
                    "a SELECT as its last statement. Then, try running your query "
                    "again."
                ),
                error_type=SupersetErrorType.INVALID_CTAS_QUERY_ERROR,
                level=ErrorLevel.ERROR,
            )
        )
    if (
        query.select_as_cta
        and query.ctas_method == CtasMethod.VIEW
        and not parsed_query.is_valid_cvas()
    ):
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    "CVAS (create view as select) can only be run with a query with "
                    "a single SELECT statement. Please make sure your query has only "
                    "a SELECT statement. Then, try running your query again."
                ),
                error_type=SupersetErrorType.INVALID_CVAS_QUERY_ERROR,
                level=ErrorLevel.ERROR,
            )
        )

    engine = database.get_sqla_engine(
        schema=query.schema,
        nullpool=True,
        user_name=user_name,
        source=QuerySource.SQL_LAB,
    )
    # Sharing a single connection and cursor across the
    # execution of all statements (if many)
    with closing(engine.raw_connection()) as conn:
        # closing the connection closes the cursor as well
        cursor = conn.cursor()
        cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
        if cancel_query_id is not None:
            query.set_extra_json_key(cancel_query_key, cancel_query_id)
            session.commit()
        statement_count = len(statements)
        for i, statement in enumerate(statements):
            # Check if stopped
            session.refresh(query)
            if query.status == QueryStatus.STOPPED:
                payload.update({"status": query.status})
                return payload

            # For CTAS we create the table only on the last statement
            apply_ctas = query.select_as_cta and (
                query.ctas_method == CtasMethod.VIEW
                or (query.ctas_method == CtasMethod.TABLE and i == len(statements) - 1)
            )

            # Run statement
            msg = f"Running statement {i+1} out of {statement_count}"
            logger.info("Query %s: %s", str(query_id), msg)
            query.set_extra_json_key("progress", msg)
            session.commit()
            try:
                result_set = execute_sql_statement(
                    statement,
                    query,
                    user_name,
                    session,
                    cursor,
                    log_params,
                    apply_ctas,
                )
            except SqlLabQueryStoppedException:
                payload.update({"status": QueryStatus.STOPPED})
                return payload
            except Exception as ex:  # pylint: disable=broad-except
                msg = str(ex)
                prefix_message = (
                    f"[Statement {i+1} out of {statement_count}]"
                    if statement_count > 1
                    else ""
                )
                payload = handle_query_error(
                    ex, query, session, payload, prefix_message
                )
                return payload

        # Commit the connection so CTA queries will create the table.
        conn.commit()

    # Success, updating the query entry in database
    query.rows = result_set.size
    query.progress = 100
    query.set_extra_json_key("progress", None)
    if query.select_as_cta:
        query.select_sql = database.select_star(
            query.tmp_table_name,
            schema=query.tmp_schema_name,
            limit=query.limit,
            show_cols=False,
            latest_partition=False,
        )
    query.end_time = now_as_float()

    use_arrow_data = store_results and cast(bool, results_backend_use_msgpack)
    data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
        result_set, db_engine_spec, use_arrow_data, expand_data
    )

    # TODO: data should be saved separately from metadata (likely in Parquet)
    payload.update(
        {
            "status": QueryStatus.SUCCESS,
            "data": data,
            "columns": all_columns,
            "selected_columns": selected_columns,
            "expanded_columns": expanded_columns,
            "query": query.to_dict(),
        }
    )
    payload["query"]["state"] = QueryStatus.SUCCESS

    if store_results and results_backend:
        key = str(uuid.uuid4())
        logger.info(
            "Query %s: Storing results in results backend, key: %s", str(query_id), key
        )
        with stats_timing("sqllab.query.results_backend_write", stats_logger):
            with stats_timing(
                "sqllab.query.results_backend_write_serialization", stats_logger
            ):
                serialized_payload = _serialize_payload(
                    payload, cast(bool, results_backend_use_msgpack)
                )
            cache_timeout = database.cache_timeout
            if cache_timeout is None:
                cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]

            compressed = zlib_compress(serialized_payload)
            logger.debug(
                "*** serialized payload size: %i", getsizeof(serialized_payload)
            )
            logger.debug("*** compressed payload size: %i", getsizeof(compressed))
            results_backend.set(key, compressed, cache_timeout)
        query.results_key = key

    query.status = QueryStatus.SUCCESS
    session.commit()

    if return_results:
        # since we're returning results we need to create non-arrow data
        if use_arrow_data:
            (
                data,
                selected_columns,
                all_columns,
                expanded_columns,
            ) = _serialize_and_expand_data(
                result_set, db_engine_spec, False, expand_data
            )
            payload.update(
                {
                    "data": data,
                    "columns": all_columns,
                    "selected_columns": selected_columns,
                    "expanded_columns": expanded_columns,
                }
            )
        return payload

    return None
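
The results-backend write near the end of the function reduces to: serialize the payload, zlib-compress it, and store it under a UUID key with a timeout. A minimal sketch of that round trip, with a plain dict standing in for the configured results backend and JSON standing in for the optional msgpack serialization:

# Sketch: serialize -> compress -> store under a key with a TTL, then read it back.
import json
import uuid
import zlib

class SimpleBackend:  # stand-in for the real results backend (e.g. a cache client)
    def __init__(self):
        self._store = {}

    def set(self, key, value, timeout=None):
        self._store[key] = (value, timeout)

    def get(self, key):
        return self._store[key][0]

backend = SimpleBackend()
payload = {"query_id": 42, "status": "success", "data": [[1, "a"], [2, "b"]]}

key = str(uuid.uuid4())
serialized = json.dumps(payload).encode("utf-8")
backend.set(key, zlib.compress(serialized), timeout=3600)

restored = json.loads(zlib.decompress(backend.get(key)).decode("utf-8"))
assert restored == payload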
Exemple #33
0
class BigQueryEngineSpec(BaseEngineSpec):
    """Engine spec for Google's BigQuery

    As contributed by @mxmzdlv on issue #945"""

    engine = "bigquery"
    engine_name = "Google BigQuery"
    max_column_name_length = 128

    parameters_schema = BigQueryParametersSchema()
    default_driver = "bigquery"
    sqlalchemy_uri_placeholder = "bigquery://{project_id}"

    # BigQuery doesn't maintain context when running multiple statements in the
    # same cursor, so we need to run all statements at once
    run_multiple_statements_as_one = True

    allows_hidden_cc_in_orderby = True
    """
    https://www.python.org/dev/peps/pep-0249/#arraysize
    raw_connections bypass the pybigquery query execution context and deal with
    raw dbapi connection directly.
    If this value is not set, the default value is set to 1, as described here,
    https://googlecloudplatform.github.io/google-cloud-python/latest/_modules/google/cloud/bigquery/dbapi/cursor.html#Cursor

    The default value of 5000 is derived from the pybigquery.
    https://github.com/mxmzdlv/pybigquery/blob/d214bb089ca0807ca9aaa6ce4d5a01172d40264e/pybigquery/sqlalchemy_bigquery.py#L102
    """
    arraysize = 5000

    _date_trunc_functions = {
        "DATE": "DATE_TRUNC",
        "DATETIME": "DATETIME_TRUNC",
        "TIME": "TIME_TRUNC",
        "TIMESTAMP": "TIMESTAMP_TRUNC",
    }

    _time_grain_expressions = {
        None:
        "{col}",
        "PT1S":
        "CAST(TIMESTAMP_SECONDS("
        "UNIX_SECONDS(CAST({col} AS TIMESTAMP))"
        ") AS {type})",
        "PT1M":
        "CAST(TIMESTAMP_SECONDS("
        "60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 60)"
        ") AS {type})",
        "PT5M":
        "CAST(TIMESTAMP_SECONDS("
        "5*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 5*60)"
        ") AS {type})",
        "PT10M":
        "CAST(TIMESTAMP_SECONDS("
        "10*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 10*60)"
        ") AS {type})",
        "PT15M":
        "CAST(TIMESTAMP_SECONDS("
        "15*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 15*60)"
        ") AS {type})",
        "PT30M":
        "CAST(TIMESTAMP_SECONDS("
        "30*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 30*60)"
        ") AS {type})",
        "PT1H":
        "{func}({col}, HOUR)",
        "P1D":
        "{func}({col}, DAY)",
        "P1W":
        "{func}({col}, WEEK)",
        "1969-12-29T00:00:00Z/P1W":
        "{func}({col}, ISOWEEK)",
        "P1M":
        "{func}({col}, MONTH)",
        "P3M":
        "{func}({col}, QUARTER)",
        "P1Y":
        "{func}({col}, YEAR)",
    }

    custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[
        str, Any]]] = {
            CONNECTION_DATABASE_PERMISSIONS_REGEX: (
                __("We were unable to connect to your database. Please "
                   "confirm that your service account has the Viewer "
                   "and Job User roles on the project."),
                SupersetErrorType.CONNECTION_DATABASE_PERMISSIONS_ERROR,
                {},
            ),
            TABLE_DOES_NOT_EXIST_REGEX: (
                __(
                    'The table "%(table)s" does not exist. '
                    "A valid table must be used to run this query.", ),
                SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR,
                {},
            ),
            COLUMN_DOES_NOT_EXIST_REGEX: (
                __('We can\'t seem to resolve column "%(column)s" at line %(location)s.'
                   ),
                SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR,
                {},
            ),
            SCHEMA_DOES_NOT_EXIST_REGEX: (
                __('The schema "%(schema)s" does not exist. '
                   "A valid schema must be used to run this query."),
                SupersetErrorType.SCHEMA_DOES_NOT_EXIST_ERROR,
                {},
            ),
            SYNTAX_ERROR_REGEX: (
                __("Please check your query for syntax errors at or near "
                   '"%(syntax_error)s". Then, try running your query again.'),
                SupersetErrorType.SYNTAX_ERROR,
                {},
            ),
        }

    @classmethod
    def convert_dttm(
            cls,
            target_type: str,
            dttm: datetime,
            db_extra: Optional[Dict[str, Any]] = None) -> Optional[str]:
        tt = target_type.upper()
        if tt == utils.TemporalType.DATE:
            return f"CAST('{dttm.date().isoformat()}' AS DATE)"
        if tt == utils.TemporalType.DATETIME:
            return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)"""
        if tt == utils.TemporalType.TIME:
            return f"""CAST('{dttm.strftime("%H:%M:%S.%f")}' AS TIME)"""
        if tt == utils.TemporalType.TIMESTAMP:
            return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)"""
        return None

    @classmethod
    def fetch_data(cls,
                   cursor: Any,
                   limit: Optional[int] = None) -> List[Tuple[Any, ...]]:
        data = super().fetch_data(cursor, limit)
        # Support the BigQuery Row type, introduced in PR #4071
        # google.cloud.bigquery.table.Row
        if data and type(data[0]).__name__ == "Row":
            data = [r.values() for r in data]  # type: ignore
        return data

    @staticmethod
    def _mutate_label(label: str) -> str:
        """
        BigQuery field_name should start with a letter or underscore and contain only
        alphanumeric characters. Labels that start with a number are prefixed with an
        underscore. Any unsupported characters are replaced with underscores and an
        md5 hash is added to the end of the label to avoid possible collisions.

        :param label: Expected expression label
        :return: Conditionally mutated label
        """
        label_hashed = "_" + md5_sha_from_str(label)

        # if label starts with number, add underscore as first character
        label_mutated = "_" + label if re.match(r"^\d", label) else label

        # replace non-alphanumeric characters with underscores
        label_mutated = re.sub(r"[^\w]+", "_", label_mutated)
        if label_mutated != label:
            # add first 5 chars from md5 hash to label to avoid possible collisions
            label_mutated += label_hashed[:6]

        return label_mutated

    @classmethod
    def _truncate_label(cls, label: str) -> str:
        """BigQuery requires column names start with either a letter or
        underscore. To make sure this is always the case, an underscore is prefixed
        to the md5 hash of the original label.

        :param label: expected expression label
        :return: truncated label
        """
        return "_" + md5_sha_from_str(label)

    @classmethod
    def normalize_indexes(
            cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Normalizes indexes for more consistency across db engines

        :param indexes: Raw indexes as returned by SQLAlchemy
        :return: cleaner, more aligned index definition
        """
        normalized_idxs = []
        # Fixing a bug/behavior observed in pybigquery==0.4.15 where
        # the index's `column_names` == [None]
        # Here we're returning only non-None indexes
        for ix in indexes:
            column_names = ix.get("column_names") or []
            ix["column_names"] = [
                col for col in column_names if col is not None
            ]
            if ix["column_names"]:
                normalized_idxs.append(ix)
        return normalized_idxs

    @classmethod
    def extra_table_metadata(cls, database: "Database", table_name: str,
                             schema_name: Optional[str]) -> Dict[str, Any]:
        indexes = database.get_indexes(table_name, schema_name)
        if not indexes:
            return {}
        partitions_columns = [
            index.get("column_names", []) for index in indexes
            if index.get("name") == "partition"
        ]
        cluster_columns = [
            index.get("column_names", []) for index in indexes
            if index.get("name") == "clustering"
        ]
        return {
            "partitions": {
                "cols": partitions_columns
            },
            "clustering": {
                "cols": cluster_columns
            },
        }

    @classmethod
    def epoch_to_dttm(cls) -> str:
        return "TIMESTAMP_SECONDS({col})"

    @classmethod
    def epoch_ms_to_dttm(cls) -> str:
        return "TIMESTAMP_MILLIS({col})"

    @classmethod
    def df_to_sql(
        cls,
        database: "Database",
        table: Table,
        df: pd.DataFrame,
        to_sql_kwargs: Dict[str, Any],
    ) -> None:
        """
        Upload data from a Pandas DataFrame to a database.

        Calls `pandas_gbq.to_gbq`, which requires `pandas_gbq` to be installed.

        Note this method does not create metadata for the table.

        :param database: The database to upload the data to
        :param table: The table to upload the data to
        :param df: The dataframe with data to be uploaded
        :param to_sql_kwargs: The kwargs to be passed to the `pandas.DataFrame.to_sql` method
        """

        try:
            # pylint: disable=import-outside-toplevel
            import pandas_gbq
            from google.oauth2 import service_account
        except ImportError as ex:
            raise Exception(
                "Could not import libraries `pandas_gbq` or `google.oauth2`, which are "
                "required to be installed in your environment in order "
                "to upload data to BigQuery") from ex

        if not table.schema:
            raise Exception("The table schema must be defined")

        engine = cls.get_engine(database)
        to_gbq_kwargs = {
            "destination_table": str(table),
            "project_id": engine.url.host
        }

        # Add credentials if they are set on the SQLAlchemy dialect.
        creds = engine.dialect.credentials_info

        if creds:
            to_gbq_kwargs[
                "credentials"] = service_account.Credentials.from_service_account_info(
                    creds)

        # Only pass through supported kwargs.
        supported_kwarg_keys = {"if_exists"}

        for key in supported_kwarg_keys:
            if key in to_sql_kwargs:
                to_gbq_kwargs[key] = to_sql_kwargs[key]

        pandas_gbq.to_gbq(df, **to_gbq_kwargs)

    @classmethod
    def build_sqlalchemy_uri(
        cls,
        parameters: BigQueryParametersType,
        encrypted_extra: Optional[Dict[str, Any]] = None,
    ) -> str:
        query = parameters.get("query", {})
        query_params = urllib.parse.urlencode(query)

        if encrypted_extra:
            credentials_info = encrypted_extra.get("credentials_info")
            if isinstance(credentials_info, str):
                credentials_info = json.loads(credentials_info)
            project_id = credentials_info.get("project_id")
        if not encrypted_extra:
            raise ValidationError("Missing service credentials")

        if project_id:
            return f"{cls.default_driver}://{project_id}/?{query_params}"

        raise ValidationError("Invalid service credentials")

    @classmethod
    def get_parameters_from_uri(
            cls,
            uri: str,
            encrypted_extra: Optional[Dict[str, str]] = None) -> Any:
        value = make_url_safe(uri)

        # Building parameters from encrypted_extra and uri
        if encrypted_extra:
            return {**encrypted_extra, "query": value.query}

        raise ValidationError("Invalid service credentials")

    @classmethod
    def get_dbapi_exception_mapping(
            cls) -> Dict[Type[Exception], Type[Exception]]:
        # pylint: disable=import-outside-toplevel
        from google.auth.exceptions import DefaultCredentialsError

        return {DefaultCredentialsError: SupersetDBAPIDisconnectionError}

    @classmethod
    def validate_parameters(
        cls,
        parameters: BigQueryParametersType  # pylint: disable=unused-argument
    ) -> List[SupersetError]:
        return []

    @classmethod
    def parameters_json_schema(cls) -> Any:
        """
        Return configuration parameters as OpenAPI.
        """
        if not cls.parameters_schema:
            return None

        spec = APISpec(
            title="Database Parameters",
            version="1.0.0",
            openapi_version="3.0.0",
            plugins=[ma_plugin],
        )

        ma_plugin.init_spec(spec)
        ma_plugin.converter.add_attribute_function(encrypted_field_properties)
        spec.components.schema(cls.__name__, schema=cls.parameters_schema)
        return spec.to_dict()["components"]["schemas"][cls.__name__]

    @classmethod
    def select_star(  # pylint: disable=too-many-arguments
        cls,
        database: "Database",
        table_name: str,
        engine: Engine,
        schema: Optional[str] = None,
        limit: int = 100,
        show_cols: bool = False,
        indent: bool = True,
        latest_partition: bool = True,
        cols: Optional[List[Dict[str, Any]]] = None,
    ) -> str:
        """
        Remove array structures from `SELECT *`.

        BigQuery supports structures and arrays of structures, eg:

            author STRUCT<name STRING, email STRING>
            trailer ARRAY<STRUCT<key STRING, value STRING>>

        When loading metadata for a table each key in the struct is displayed as a
        separate pseudo-column, eg:

            - author
            - author.name
            - author.email
            - trailer
            - trailer.key
            - trailer.value

        When generating the `SELECT *` statement we want to remove any keys from
        structs inside an array, since selecting them results in an error. The correct
        select statement should look like this:

            SELECT
              `author`,
              `author`.`name`,
              `author`.`email`,
              `trailer`
            FROM
              table

        Selecting `trailer.key` or `trailer.value` results in an error, as opposed to
        selecting `author.name`, since they are keys in a structure inside an array.

        This method removes any array pseudo-columns.
        """
        if cols:
            # For arrays of structs, remove the child columns, otherwise the query
            # will fail.
            array_prefixes = {
                col["name"]
                for col in cols if isinstance(col["type"], sqltypes.ARRAY)
            }
            cols = [
                col for col in cols if "." not in col["name"]
                or col["name"].split(".")[0] not in array_prefixes
            ]

        return super().select_star(
            database,
            table_name,
            engine,
            schema,
            limit,
            show_cols,
            indent,
            latest_partition,
            cols,
        )

    @classmethod
    def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]:
        """
        Label columns using their fully qualified name.

        BigQuery supports columns of type `struct`, which are basically dictionaries.
        When loading metadata for a table with struct columns, each key in the struct
        is displayed as a separate pseudo-column, eg:

            author STRUCT<name STRING, email STRING>

        Will be shown as 3 columns:

            - author
            - author.name
            - author.email

        If we select those fields:

            SELECT `author`, `author`.`name`, `author`.`email` FROM table

        The resulting columns will be called "author", "name", and "email". This may
        result in a clash with other columns. To prevent that, we explicitly label
        the columns using their fully qualified name, so we end up with "author",
        "author__name" and "author__email", respectively.
        """
        return [
            column(c["name"]).label(c["name"].replace(".", "__")) for c in cols
        ]
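
The label-mutation rules documented in `_mutate_label` above can be checked in isolation. The sketch below mirrors those rules, with `hashlib.md5` standing in for Superset's `md5_sha_from_str` helper: labels that start with a digit get a leading underscore, unsupported characters collapse to underscores, and any mutated label gets a short hash suffix to avoid collisions.

# Sketch: BigQuery-safe label mutation, mirroring the rules described above.
import hashlib
import re

def mutate_label(label: str) -> str:
    label_hashed = "_" + hashlib.md5(label.encode("utf-8")).hexdigest()
    # prefix labels that start with a digit
    mutated = "_" + label if re.match(r"^\d", label) else label
    # replace runs of unsupported characters with underscores
    mutated = re.sub(r"[^\w]+", "_", mutated)
    if mutated != label:
        # append "_" plus the first 5 hash characters to avoid collisions
        mutated += label_hashed[:6]
    return mutated

print(mutate_label("author"))      # unchanged: author
print(mutate_label("2nd metric"))  # e.g. "_2nd_metric" plus "_" and 5 hash chars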
Exemple #34
0
def custom_form_factory(form, field_types=None, field_slugs=None,
                        excluded_field_types=None,
                        registration_fields=False):
    fields = (CustomField.query.filter_by(meeting_id=g.meeting.id)
              .order_by(CustomField.sort))

    if field_types:
        fields = fields.filter(CustomField.field_type.in_(field_types))

    if field_slugs:
        fields = fields.filter(CustomField.slug.in_(field_slugs))

    if excluded_field_types:
        fields = fields.filter(
            ~CustomField.field_type.in_(excluded_field_types))

    if registration_fields:
        fields = fields.for_registration()

    if getattr(form, 'CUSTOM_FIELDS_TYPE', None):
        fields = fields.filter_by(custom_field_type=form.CUSTOM_FIELDS_TYPE)

    form_attrs = {
        '_custom_fields': OrderedMultiDict({c.slug: c for c in fields}),
    }

    for f in fields:
        attrs = {'label': unicode(CustomFieldLabel(f.label)),
                 'validators': [],
                 'render_kw': {},
                 'description': f.hint}

        data = _CUSTOM_FIELDS_MAP[f.field_type.code]

        # overwrite data if _CUSTOM_FIELDS_MAP attribute is present on form
        form_fields_map = getattr(form, '_CUSTOM_FIELDS_MAP', None)
        if form_fields_map:
            try:
                data = form_fields_map[f.field_type.code]
            except KeyError:
                pass

        if f.required:
            attrs['validators'].append(DataRequired())
        attrs['validators'].extend(data.get('validators', []))

        if f.max_length:
            attrs['validators'].append(Length(max=f.max_length))

        if f.field_type.code == CustomField.SELECT:
            query = CustomFieldChoice.query.filter_by(custom_field=f)
            attrs['choices'] = [(unicode(c.value), __(c.value.english))
                                for c in query]
            if not f.required:
                attrs['choices'] = [('', '---')] + attrs['choices']
            if f.slug == 'title':
                attrs['choices'] = [choice for choice in attrs['choices']
                                    if choice[0] in app.config['TITLE_CHOICES']]
            attrs['coerce'] = unicode

        if f.field_type.code == CustomField.LANGUAGE:
            attrs['choices'] = [i for i in Participant.LANGUAGE_CHOICES
                                if i[0].lower() in app.config['TRANSLATIONS']]
            if not f.required:
                attrs['choices'] = [('', '---')] + attrs['choices']

            attrs['coerce'] = unicode

        if f.field_type.code == CustomField.CATEGORY:
            query = Category.get_categories_for_meeting(
                form.CUSTOM_FIELDS_TYPE)
            if registration_fields:
                query = query.filter_by(visible_on_registration_form=True)
            attrs['choices'] = [(c.id, c) for c in query]
            attrs['coerce'] = int

        if f.field_type.code in (CustomField.MULTI_CHECKBOX, CustomField.RADIO):
            query = CustomFieldChoice.query.filter_by(custom_field=f)
            attrs['choices'] = [(unicode(c.value), c.value) for c in query]
            attrs['coerce'] = unicode

        if f.field_type.code == CustomField.IMAGE and f.photo_size and f.photo_size.code:
            attrs['render_kw']["data-photoSize"] = f.photo_size.code
            for coord in ("x1", "y1", "x2", "y2"):
                form_attrs['%s_%s_' % (f.slug, coord)] = HiddenField(default=0)

        # set field to form
        # _set_rules_for_custom_fields(f, attrs)
        field = data['field'](**attrs)
        setattr(field, 'field_type', f.field_type.code)
        form_attrs[f.slug] = field

    form_attrs['rules'] = Rule.get_rules_for_fields(fields)
    return type(form)(form.__name__, (form,), form_attrs)
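
The factory's last line builds the form class dynamically with `type()`, so the generated class inherits from the passed-in form and carries the custom fields as class attributes. A framework-free sketch of that pattern (the names below are illustrative):

# Sketch: building a subclass dynamically with type(), as the factory's return does.
class BaseForm:
    pass

def form_factory(base, extra_attrs):
    # type(name, bases, attrs) builds a new class; type(base) preserves any metaclass.
    return type(base)(base.__name__, (base,), extra_attrs)

DynamicForm = form_factory(BaseForm, {"title": "a field instance would go here"})
print(issubclass(DynamicForm, BaseForm), DynamicForm.title)  # True ...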
Exemple #35
0
        elif obj.end_dttm < obj.start_dttm:
            raise Exception('Annotation end time must be no earlier than start time.')

    def pre_update(self, obj):
        self.pre_add(obj)


class AnnotationLayerModelView(SupersetModelView, DeleteMixin):
    datamodel = SQLAInterface(AnnotationLayer)
    list_columns = ['id', 'name']
    edit_columns = ['name', 'descr']
    add_columns = edit_columns


appbuilder.add_view(
    AnnotationLayerModelView,
    'Annotation Layers',
    label=__('Annotation Layers'),
    icon='fa-comment',
    category='Manage',
    category_label=__('Manage'),
    category_icon='')
appbuilder.add_view(
    AnnotationModelView,
    'Annotations',
    label=__('Annotations'),
    icon='fa-comments',
    category='Manage',
    category_label=__('Manage'),
    category_icon='')
Exemple #36
0
    def init_views(self) -> None:
        #
        # We're doing local imports, as several of them import
        # models which in turn try to import
        # the global Flask app
        #
        # pylint: disable=import-outside-toplevel,too-many-locals,too-many-statements
        from superset.annotation_layers.annotations.api import AnnotationRestApi
        from superset.annotation_layers.api import AnnotationLayerRestApi
        from superset.async_events.api import AsyncEventsRestApi
        from superset.cachekeys.api import CacheRestApi
        from superset.charts.api import ChartRestApi
        from superset.charts.data.api import ChartDataRestApi
        from superset.connectors.druid.views import (
            Druid,
            DruidClusterModelView,
            DruidColumnInlineView,
            DruidDatasourceModelView,
            DruidMetricInlineView,
        )
        from superset.connectors.sqla.views import (
            RowLevelSecurityFiltersModelView,
            SqlMetricInlineView,
            TableColumnInlineView,
            TableModelView,
        )
        from superset.css_templates.api import CssTemplateRestApi
        from superset.dashboards.api import DashboardRestApi
        from superset.dashboards.filter_sets.api import FilterSetRestApi
        from superset.dashboards.filter_state.api import DashboardFilterStateRestApi
        from superset.databases.api import DatabaseRestApi
        from superset.datasets.api import DatasetRestApi
        from superset.datasets.columns.api import DatasetColumnsRestApi
        from superset.datasets.metrics.api import DatasetMetricRestApi
        from superset.explore.form_data.api import ExploreFormDataRestApi
        from superset.queries.api import QueryRestApi
        from superset.queries.saved_queries.api import SavedQueryRestApi
        from superset.reports.api import ReportScheduleRestApi
        from superset.reports.logs.api import ReportExecutionLogRestApi
        from superset.security.api import SecurityRestApi
        from superset.views.access_requests import AccessRequestsModelView
        from superset.views.alerts import (
            AlertLogModelView,
            AlertModelView,
            AlertObservationModelView,
            AlertView,
            ReportView,
        )
        from superset.views.annotations import (
            AnnotationLayerModelView,
            AnnotationModelView,
        )
        from superset.views.api import Api
        from superset.views.chart.views import SliceAsync, SliceModelView
        from superset.views.core import Superset
        from superset.views.css_templates import (
            CssTemplateAsyncModelView,
            CssTemplateModelView,
        )
        from superset.views.dashboard.views import (
            Dashboard,
            DashboardModelView,
            DashboardModelViewAsync,
        )
        from superset.views.database.views import (
            ColumnarToDatabaseView,
            CsvToDatabaseView,
            DatabaseView,
            ExcelToDatabaseView,
        )
        from superset.views.datasource.views import Datasource
        from superset.views.dynamic_plugins import DynamicPluginsView
        from superset.views.key_value import KV
        from superset.views.log.api import LogRestApi
        from superset.views.log.views import LogModelView
        from superset.views.redirects import R
        from superset.views.schedules import (
            DashboardEmailScheduleView,
            SliceEmailScheduleView,
        )
        from superset.views.sql_lab import (
            SavedQueryView,
            SavedQueryViewApi,
            SqlLab,
            TableSchemaView,
            TabStateView,
        )
        from superset.views.tags import TagView

        #
        # Setup API views
        #
        appbuilder.add_api(AnnotationRestApi)
        appbuilder.add_api(AnnotationLayerRestApi)
        appbuilder.add_api(AsyncEventsRestApi)
        appbuilder.add_api(CacheRestApi)
        appbuilder.add_api(ChartRestApi)
        appbuilder.add_api(ChartDataRestApi)
        appbuilder.add_api(CssTemplateRestApi)
        appbuilder.add_api(DashboardFilterStateRestApi)
        appbuilder.add_api(DashboardRestApi)
        appbuilder.add_api(DatabaseRestApi)
        appbuilder.add_api(DatasetRestApi)
        appbuilder.add_api(DatasetColumnsRestApi)
        appbuilder.add_api(DatasetMetricRestApi)
        appbuilder.add_api(ExploreFormDataRestApi)
        appbuilder.add_api(FilterSetRestApi)
        appbuilder.add_api(QueryRestApi)
        appbuilder.add_api(ReportScheduleRestApi)
        appbuilder.add_api(ReportExecutionLogRestApi)
        appbuilder.add_api(SavedQueryRestApi)
        #
        # Setup regular views
        #
        appbuilder.add_link(
            "Home",
            label=__("Home"),
            href="/superset/welcome/",
            cond=lambda: bool(appbuilder.app.config["LOGO_TARGET_PATH"]),
        )
        appbuilder.add_view(
            AnnotationLayerModelView,
            "Annotation Layers",
            label=__("Annotation Layers"),
            icon="fa-comment",
            category="Manage",
            category_label=__("Manage"),
            category_icon="",
        )
        appbuilder.add_view(
            DashboardModelView,
            "Dashboards",
            label=__("Dashboards"),
            icon="fa-dashboard",
            category="",
            category_icon="",
        )
        appbuilder.add_view(
            SliceModelView,
            "Charts",
            label=__("Charts"),
            icon="fa-bar-chart",
            category="",
            category_icon="",
        )
        appbuilder.add_view(
            DynamicPluginsView,
            "Plugins",
            label=__("Plugins"),
            category="Manage",
            category_label=__("Manage"),
            icon="fa-puzzle-piece",
            menu_cond=lambda: feature_flag_manager.is_feature_enabled(
                "DYNAMIC_PLUGINS"),
        )
        appbuilder.add_view(
            CssTemplateModelView,
            "CSS Templates",
            label=__("CSS Templates"),
            icon="fa-css3",
            category="Manage",
            category_label=__("Manage"),
            category_icon="",
        )
        appbuilder.add_view(
            RowLevelSecurityFiltersModelView,
            "Row Level Security",
            label=__("Row Level Security"),
            category="Security",
            category_label=__("Security"),
            icon="fa-lock",
            menu_cond=lambda: feature_flag_manager.is_feature_enabled(
                "ROW_LEVEL_SECURITY"),
        )

        #
        # Setup views with no menu
        #
        appbuilder.add_view_no_menu(Api)
        appbuilder.add_view_no_menu(CssTemplateAsyncModelView)
        appbuilder.add_view_no_menu(CsvToDatabaseView)
        appbuilder.add_view_no_menu(ExcelToDatabaseView)
        appbuilder.add_view_no_menu(ColumnarToDatabaseView)
        appbuilder.add_view_no_menu(Dashboard)
        appbuilder.add_view_no_menu(DashboardModelViewAsync)
        appbuilder.add_view_no_menu(Datasource)
        appbuilder.add_view_no_menu(KV)
        appbuilder.add_view_no_menu(R)
        appbuilder.add_view_no_menu(SavedQueryView)
        appbuilder.add_view_no_menu(SavedQueryViewApi)
        appbuilder.add_view_no_menu(SliceAsync)
        appbuilder.add_view_no_menu(SqlLab)
        appbuilder.add_view_no_menu(SqlMetricInlineView)
        appbuilder.add_view_no_menu(AnnotationModelView)
        appbuilder.add_view_no_menu(Superset)
        appbuilder.add_view_no_menu(TableColumnInlineView)
        appbuilder.add_view_no_menu(TableModelView)
        appbuilder.add_view_no_menu(TableSchemaView)
        appbuilder.add_view_no_menu(TabStateView)
        appbuilder.add_view_no_menu(TagView)

        #
        # Add links
        #
        appbuilder.add_link(
            "Import Dashboards",
            label=__("Import Dashboards"),
            href="/superset/import_dashboards/",
            icon="fa-cloud-upload",
            category="Manage",
            category_label=__("Manage"),
            category_icon="fa-wrench",
            cond=lambda: not feature_flag_manager.is_feature_enabled(
                "VERSIONED_EXPORT"),
        )
        appbuilder.add_link(
            "SQL Editor",
            label=_("SQL Editor"),
            href="/superset/sqllab/",
            category_icon="fa-flask",
            icon="fa-flask",
            category="SQL Lab",
            category_label=__("SQL Lab"),
        )
        appbuilder.add_link(
            __("Saved Queries"),
            href="/savedqueryview/list/",
            icon="fa-save",
            category="SQL Lab",
        )
        appbuilder.add_link(
            "Query Search",
            label=_("Query History"),
            href="/superset/sqllab/history/",
            icon="fa-search",
            category_icon="fa-flask",
            category="SQL Lab",
            category_label=__("SQL Lab"),
        )
        appbuilder.add_view(
            DatabaseView,
            "Databases",
            label=__("Databases"),
            icon="fa-database",
            category="Data",
            category_label=__("Data"),
            category_icon="fa-database",
        )
        appbuilder.add_link(
            "Datasets",
            label=__("Datasets"),
            href="/tablemodelview/list/",
            icon="fa-table",
            category="Data",
            category_label=__("Data"),
            category_icon="fa-table",
        )
        appbuilder.add_separator("Data")

        appbuilder.add_api(LogRestApi)
        appbuilder.add_view(
            LogModelView,
            "Action Log",
            label=__("Action Log"),
            category="Security",
            category_label=__("Security"),
            icon="fa-list-ol",
            menu_cond=lambda: (self.config["FAB_ADD_SECURITY_VIEWS"] and self.
                               config["SUPERSET_LOG_VIEW"]),
        )
        appbuilder.add_api(SecurityRestApi)
        #
        # Conditionally setup email views
        #
        if self.config["ENABLE_SCHEDULED_EMAIL_REPORTS"]:
            logging.warning(
                "ENABLE_SCHEDULED_EMAIL_REPORTS "
                "is deprecated and will be removed in version 2.0.0")

        appbuilder.add_separator(
            "Manage",
            cond=lambda: self.config["ENABLE_SCHEDULED_EMAIL_REPORTS"])
        appbuilder.add_view(
            DashboardEmailScheduleView,
            "Dashboard Email Schedules",
            label=__("Dashboard Emails"),
            category="Manage",
            category_label=__("Manage"),
            icon="fa-search",
            menu_cond=lambda: self.config["ENABLE_SCHEDULED_EMAIL_REPORTS"],
        )
        appbuilder.add_view(
            SliceEmailScheduleView,
            "Chart Emails",
            label=__("Chart Email Schedules"),
            category="Manage",
            category_label=__("Manage"),
            icon="fa-search",
            menu_cond=lambda: self.config["ENABLE_SCHEDULED_EMAIL_REPORTS"],
        )

        if self.config["ENABLE_ALERTS"]:
            logging.warning(
                "ENABLE_ALERTS is deprecated and will be removed in version 2.0.0"
            )

        appbuilder.add_view(
            AlertModelView,
            "Alerts",
            label=__("Alerts"),
            category="Manage",
            category_label=__("Manage"),
            icon="fa-exclamation-triangle",
            menu_cond=lambda: bool(self.config["ENABLE_ALERTS"]),
        )
        appbuilder.add_view_no_menu(AlertLogModelView)
        appbuilder.add_view_no_menu(AlertObservationModelView)

        appbuilder.add_view(
            AlertView,
            "Alerts & Report",
            label=__("Alerts & Reports"),
            category="Manage",
            category_label=__("Manage"),
            icon="fa-exclamation-triangle",
            menu_cond=lambda: feature_flag_manager.is_feature_enabled(
                "ALERT_REPORTS"),
        )
        appbuilder.add_view_no_menu(ReportView)

        appbuilder.add_view(
            AccessRequestsModelView,
            "Access requests",
            label=__("Access requests"),
            category="Security",
            category_label=__("Security"),
            icon="fa-table",
            menu_cond=lambda: bool(self.config["ENABLE_ACCESS_REQUEST"]),
        )

        #
        # Druid Views
        #
        appbuilder.add_separator(
            "Data", cond=lambda: bool(self.config["DRUID_IS_ACTIVE"]))
        appbuilder.add_view(
            DruidDatasourceModelView,
            "Druid Datasources",
            label=__("Druid Datasources"),
            category="Data",
            category_label=__("Data"),
            icon="fa-cube",
            menu_cond=lambda: bool(self.config["DRUID_IS_ACTIVE"]),
        )
        appbuilder.add_view(
            DruidClusterModelView,
            name="Druid Clusters",
            label=__("Druid Clusters"),
            icon="fa-cubes",
            category="Data",
            category_label=__("Data"),
            category_icon="fa-database",
            menu_cond=lambda: bool(self.config["DRUID_IS_ACTIVE"]),
        )
        appbuilder.add_view_no_menu(DruidMetricInlineView)
        appbuilder.add_view_no_menu(DruidColumnInlineView)
        appbuilder.add_view_no_menu(Druid)

        appbuilder.add_link(
            "Scan New Datasources",
            label=__("Scan New Datasources"),
            href="/druid/scan_new_datasources/",
            category="Data",
            category_label=__("Data"),
            category_icon="fa-database",
            icon="fa-refresh",
            cond=lambda: bool(self.config["DRUID_IS_ACTIVE"] and self.config[
                "DRUID_METADATA_LINKS_ENABLED"]),
        )
        appbuilder.add_link(
            "Refresh Druid Metadata",
            label=__("Refresh Druid Metadata"),
            href="/druid/refresh_datasources/",
            category="Data",
            category_label=__("Data"),
            category_icon="fa-database",
            icon="fa-cog",
            cond=lambda: bool(self.config["DRUID_IS_ACTIVE"] and self.config[
                "DRUID_METADATA_LINKS_ENABLED"]),
        )
        appbuilder.add_separator(
            "Data", cond=lambda: bool(self.config["DRUID_IS_ACTIVE"]))
Exemple #37
0
    }

    def pre_add(self, db):
        conn = sqla.engine.url.make_url(db.sqlalchemy_uri)
        db.password = conn.password
        conn.password = "******" * 10 if conn.password else None
        db.sqlalchemy_uri = str(conn)  # hides the password

    def pre_update(self, db):
        self.pre_add(db)


appbuilder.add_view(
    DatabaseView,
    "Databases",
    label=__("Databases"),
    icon="fa-database",
    category="Sources",
    category_label=__("Sources"),
    category_icon='fa-database',)
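
A hedged sketch of the password-masking idea in `pre_add` above: pull the password out of the SQLAlchemy URI, store it separately, and keep only a masked URI. It assumes SQLAlchemy 1.4+, where URL objects are immutable and expose `.set()`; the connection URI is an example value.

# Sketch: split the real password out of a SQLAlchemy URI and keep a masked copy.
from sqlalchemy.engine.url import make_url

def mask_uri(sqlalchemy_uri: str):
    url = make_url(sqlalchemy_uri)
    password = url.password
    masked = url.set(password="X" * 10) if password else url
    return str(masked), password

masked_uri, real_password = mask_uri("postgresql://scott:tiger@localhost/test")
print(masked_uri)      # e.g. postgresql://scott:XXXXXXXXXX@localhost/test
print(real_password)   # tiger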


class TableModelView(CaravelModelView, DeleteMixin):  # noqa
    datamodel = SQLAInterface(models.SqlaTable)
    list_columns = [
        'table_link', 'database', 'sql_link', 'is_featured',
        'changed_by_', 'changed_on_']
    order_columns = [
        'table_link', 'database', 'sql_link', 'is_featured', 'changed_on_']
    add_columns = [
        'table_name', 'database', 'schema',
Exemple #38
0
                   YamlExportMixin):  # pylint: disable=too-many-ancestors
    datamodel = SQLAInterface(models.Database)

    add_template = "superset/models/database/add.html"
    edit_template = "superset/models/database/edit.html"
    validators_columns = {"sqlalchemy_uri": [sqlalchemy_uri_form_validator]}

    yaml_dict_key = "databases"

    def _delete(self, pk):
        DeleteMixin._delete(self, pk)


appbuilder.add_link(
    "Import Dashboards",
    label=__("Import Dashboards"),
    href="/superset/import_dashboards",
    icon="fa-cloud-upload",
    category="Manage",
    category_label=__("Manage"),
    category_icon="fa-wrench",
)

appbuilder.add_view(
    DatabaseView,
    "Databases",
    label=__("Databases"),
    icon="fa-database",
    category="Sources",
    category_label=__("Sources"),
    category_icon="fa-database",
Exemple #39
0
    }

    def pre_add(self, cluster):
        security_manager.merge_perm('database_access', cluster.perm)

    def pre_update(self, cluster):
        self.pre_add(cluster)

    def _delete(self, pk):
        DeleteMixin._delete(self, pk)


appbuilder.add_view(
    DruidClusterModelView,
    name='Druid Clusters',
    label=__('Druid Clusters'),
    icon='fa-cubes',
    category='Sources',
    category_label=__('Sources'),
    category_icon='fa-database',
)


class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):  # noqa
    datamodel = SQLAInterface(models.DruidDatasource)

    list_title = _('List Druid Datasource')
    show_title = _('Show Druid Datasource')
    add_title = _('Add Druid Datasource')
    edit_title = _('Edit Druid Datasource')
Exemple #40
0
    def init_views(self) -> None:
        #
        # We're doing local imports, as several of them import
        # models which in turn try to import
        # the global Flask app
        #
        # pylint: disable=too-many-locals
        # pylint: disable=too-many-statements
        # pylint: disable=too-many-branches
        from superset.annotation_layers.api import AnnotationLayerRestApi
        from superset.annotation_layers.annotations.api import AnnotationRestApi
        from superset.cachekeys.api import CacheRestApi
        from superset.charts.api import ChartRestApi
        from superset.connectors.druid.views import (
            Druid,
            DruidClusterModelView,
            DruidColumnInlineView,
            DruidDatasourceModelView,
            DruidMetricInlineView,
        )
        from superset.connectors.sqla.views import (
            RowLevelSecurityFiltersModelView,
            SqlMetricInlineView,
            TableColumnInlineView,
            TableModelView,
        )
        from superset.css_templates.api import CssTemplateRestApi
        from superset.dashboards.api import DashboardRestApi
        from superset.databases.api import DatabaseRestApi
        from superset.datasets.api import DatasetRestApi
        from superset.queries.api import QueryRestApi
        from superset.queries.saved_queries.api import SavedQueryRestApi
        from superset.reports.api import ReportScheduleRestApi
        from superset.reports.logs.api import ReportExecutionLogRestApi
        from superset.views.access_requests import AccessRequestsModelView
        from superset.views.alerts import (
            AlertLogModelView,
            AlertModelView,
            AlertObservationModelView,
        )
        from superset.views.annotations import (
            AnnotationLayerModelView,
            AnnotationModelView,
        )
        from superset.views.api import Api
        from superset.views.chart.views import SliceAsync, SliceModelView
        from superset.views.core import Superset
        from superset.views.css_templates import (
            CssTemplateAsyncModelView,
            CssTemplateModelView,
        )
        from superset.views.dashboard.views import (
            Dashboard,
            DashboardModelView,
            DashboardModelViewAsync,
        )
        from superset.views.database.views import (
            CsvToDatabaseView,
            DatabaseView,
            ExcelToDatabaseView,
        )
        from superset.views.datasource import Datasource
        from superset.views.key_value import KV
        from superset.views.log.api import LogRestApi
        from superset.views.log.views import LogModelView
        from superset.views.redirects import R
        from superset.views.schedules import (
            DashboardEmailScheduleView,
            SliceEmailScheduleView,
        )
        from superset.views.sql_lab import (
            SavedQueryView,
            SavedQueryViewApi,
            SqlLab,
            TableSchemaView,
            TabStateView,
        )
        from superset.views.tags import TagView

        #
        # Setup API views
        #
        appbuilder.add_api(AnnotationRestApi)
        appbuilder.add_api(AnnotationLayerRestApi)
        appbuilder.add_api(CacheRestApi)
        appbuilder.add_api(ChartRestApi)
        appbuilder.add_api(CssTemplateRestApi)
        appbuilder.add_api(DashboardRestApi)
        appbuilder.add_api(DatabaseRestApi)
        appbuilder.add_api(DatasetRestApi)
        appbuilder.add_api(QueryRestApi)
        appbuilder.add_api(SavedQueryRestApi)
        if feature_flag_manager.is_feature_enabled("ALERTS_REPORTS"):
            appbuilder.add_api(ReportScheduleRestApi)
            appbuilder.add_api(ReportExecutionLogRestApi)
        #
        # Setup regular views
        #
        appbuilder.add_view(
            AnnotationLayerModelView,
            "Annotation Layers",
            label=__("Annotation Layers"),
            icon="fa-comment",
            category="Manage",
            category_label=__("Manage"),
            category_icon="",
        )
        appbuilder.add_view(
            DatabaseView,
            "Databases",
            label=__("Databases"),
            icon="fa-database",
            category="Data",
            category_label=__("Data"),
            category_icon="fa-database",
        )
        appbuilder.add_link(
            "Datasets",
            label=__("Datasets"),
            href="/tablemodelview/list/?_flt_1_is_sqllab_view=y",
            icon="fa-table",
            category="Data",
            category_label=__("Data"),
            category_icon="fa-table",
        )
        appbuilder.add_separator("Data")
        appbuilder.add_view(
            SliceModelView,
            "Charts",
            label=__("Charts"),
            icon="fa-bar-chart",
            category="",
            category_icon="",
        )
        appbuilder.add_view(
            DashboardModelView,
            "Dashboards",
            label=__("Dashboards"),
            icon="fa-dashboard",
            category="",
            category_icon="",
        )
        appbuilder.add_view(
            CssTemplateModelView,
            "CSS Templates",
            label=__("CSS Templates"),
            icon="fa-css3",
            category="Manage",
            category_label=__("Manage"),
            category_icon="",
        )
        if feature_flag_manager.is_feature_enabled("ROW_LEVEL_SECURITY"):
            appbuilder.add_view(
                RowLevelSecurityFiltersModelView,
                "Row Level Security",
                label=__("Row level security"),
                category="Security",
                category_label=__("Security"),
                icon="fa-lock",
            )

        #
        # Setup views with no menu
        #
        appbuilder.add_view_no_menu(Api)
        appbuilder.add_view_no_menu(CssTemplateAsyncModelView)
        appbuilder.add_view_no_menu(CsvToDatabaseView)
        appbuilder.add_view_no_menu(ExcelToDatabaseView)
        appbuilder.add_view_no_menu(Dashboard)
        appbuilder.add_view_no_menu(DashboardModelViewAsync)
        appbuilder.add_view_no_menu(Datasource)

        if feature_flag_manager.is_feature_enabled("KV_STORE"):
            appbuilder.add_view_no_menu(KV)

        appbuilder.add_view_no_menu(R)
        appbuilder.add_view_no_menu(SavedQueryView)
        appbuilder.add_view_no_menu(SavedQueryViewApi)
        appbuilder.add_view_no_menu(SliceAsync)
        appbuilder.add_view_no_menu(SqlLab)
        appbuilder.add_view_no_menu(SqlMetricInlineView)
        appbuilder.add_view_no_menu(AnnotationModelView)
        appbuilder.add_view_no_menu(Superset)
        appbuilder.add_view_no_menu(TableColumnInlineView)
        appbuilder.add_view_no_menu(TableModelView)
        appbuilder.add_view_no_menu(TableSchemaView)
        appbuilder.add_view_no_menu(TabStateView)

        if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):
            appbuilder.add_view_no_menu(TagView)

        #
        # Add links
        #
        appbuilder.add_link(
            "Import Dashboards",
            label=__("Import Dashboards"),
            href="/superset/import_dashboards",
            icon="fa-cloud-upload",
            category="Manage",
            category_label=__("Manage"),
            category_icon="fa-wrench",
        )
        appbuilder.add_link(
            "SQL Editor",
            label=_("SQL Editor"),
            href="/superset/sqllab",
            category_icon="fa-flask",
            icon="fa-flask",
            category="SQL Lab",
            category_label=__("SQL Lab"),
        )
        appbuilder.add_link(
            __("Saved Queries"),
            href="/sqllab/my_queries/",
            icon="fa-save",
            category="SQL Lab",
        )
        appbuilder.add_link(
            "Query Search",
            label=_("Query Search"),
            href="/superset/sqllab#search",
            icon="fa-search",
            category_icon="fa-flask",
            category="SQL Lab",
            category_label=__("SQL Lab"),
        )
        if self.config["CSV_EXTENSIONS"].intersection(
            self.config["ALLOWED_EXTENSIONS"]
        ):
            appbuilder.add_link(
                "Upload a CSV",
                label=__("Upload a CSV"),
                href="/csvtodatabaseview/form",
                icon="fa-upload",
                category="Data",
                category_label=__("Data"),
                category_icon="fa-wrench",
            )
        try:
            import xlrd  # pylint: disable=unused-import

            if self.config["EXCEL_EXTENSIONS"].intersection(
                self.config["ALLOWED_EXTENSIONS"]
            ):
                appbuilder.add_link(
                    "Upload Excel",
                    label=__("Upload Excel"),
                    href="/exceltodatabaseview/form",
                    icon="fa-upload",
                    category="Data",
                    category_label=__("Data"),
                    category_icon="fa-wrench",
                )
        except ImportError:
            pass

        #
        # Conditionally setup log views
        #
        if self.config["FAB_ADD_SECURITY_VIEWS"] and self.config["SUPERSET_LOG_VIEW"]:
            appbuilder.add_api(LogRestApi)
            appbuilder.add_view(
                LogModelView,
                "Action Log",
                label=__("Action Log"),
                category="Security",
                category_label=__("Security"),
                icon="fa-list-ol",
            )

        #
        # Conditionally setup email views
        #
        if self.config["ENABLE_SCHEDULED_EMAIL_REPORTS"]:
            appbuilder.add_separator("Manage")
            appbuilder.add_view(
                DashboardEmailScheduleView,
                "Dashboard Email Schedules",
                label=__("Dashboard Emails"),
                category="Manage",
                category_label=__("Manage"),
                icon="fa-search",
            )
            appbuilder.add_view(
                SliceEmailScheduleView,
                "Chart Emails",
                label=__("Chart Email Schedules"),
                category="Manage",
                category_label=__("Manage"),
                icon="fa-search",
            )

        if self.config["ENABLE_ALERTS"]:
            appbuilder.add_view(
                AlertModelView,
                "Alerts",
                label=__("Alerts"),
                category="Manage",
                category_label=__("Manage"),
                icon="fa-exclamation-triangle",
            )
            appbuilder.add_view_no_menu(AlertObservationModelView)
            appbuilder.add_view_no_menu(AlertLogModelView)

        #
        # Conditionally add Access Request Model View
        #
        if self.config["ENABLE_ACCESS_REQUEST"]:
            appbuilder.add_view(
                AccessRequestsModelView,
                "Access requests",
                label=__("Access requests"),
                category="Security",
                category_label=__("Security"),
                icon="fa-table",
            )

        #
        # Conditionally setup Druid Views
        #
        if self.config["DRUID_IS_ACTIVE"]:
            appbuilder.add_separator("Data")
            appbuilder.add_view(
                DruidDatasourceModelView,
                "Druid Datasources",
                label=__("Druid Datasources"),
                category="Data",
                category_label=__("Data"),
                icon="fa-cube",
            )
            appbuilder.add_view(
                DruidClusterModelView,
                name="Druid Clusters",
                label=__("Druid Clusters"),
                icon="fa-cubes",
                category="Data",
                category_label=__("Data"),
                category_icon="fa-database",
            )
            appbuilder.add_view_no_menu(DruidMetricInlineView)
            appbuilder.add_view_no_menu(DruidColumnInlineView)
            appbuilder.add_view_no_menu(Druid)

            if self.config["DRUID_METADATA_LINKS_ENABLED"]:
                appbuilder.add_link(
                    "Scan New Datasources",
                    label=__("Scan New Datasources"),
                    href="/druid/scan_new_datasources/",
                    category="Data",
                    category_label=__("Data"),
                    category_icon="fa-database",
                    icon="fa-refresh",
                )
                appbuilder.add_link(
                    "Refresh Druid Metadata",
                    label=__("Refresh Druid Metadata"),
                    href="/druid/refresh_datasources/",
                    category="Data",
                    category_label=__("Data"),
                    category_icon="fa-database",
                    icon="fa-cog",
                )
            appbuilder.add_separator("Data")
Exemple #41
0
class QueryView(SupersetModelView):
    datamodel = SQLAInterface(Query)
    list_columns = ['user', 'database', 'status', 'start_time', 'end_time']
    label_columns = {
        'user': _('User'),
        'database': _('Database'),
        'status': _('Status'),
        'start_time': _('Start Time'),
        'end_time': _('End Time'),
    }


appbuilder.add_view(
    QueryView,
    'Queries',
    label=__('Queries'),
    category='Manage',
    category_label=__('Manage'),
    icon='fa-search')


class SavedQueryView(SupersetModelView, DeleteMixin):
    datamodel = SQLAInterface(SavedQuery)

    list_title = _('List Saved Query')
    show_title = _('Show Saved Query')
    add_title = _('Add Saved Query')
    edit_title = _('Edit Saved Query')

    list_columns = [
        'label', 'user', 'database', 'schema', 'description',
Exemple #42
0
                t.fetch_metadata()
                successes.append(t)
            except Exception:
                failures.append(t)

        if len(successes) > 0:
            success_msg = _(
                'Metadata refreshed for the following table(s): %(tables)s',
                tables=', '.join([t.table_name for t in successes]))
            flash(success_msg, 'info')
        if len(failures) > 0:
            failure_msg = _(
                'Unable to retrieve metadata for the following table(s): %(tables)s',
                tables=', '.join([t.table_name for t in failures]))
            flash(failure_msg, 'danger')

        return redirect('/tablemodelview/list/')


appbuilder.add_view_no_menu(TableModelView)
appbuilder.add_link(
    'Tables',
    label=__('Tables'),
    href='/tablemodelview/list/?_flt_1_is_sqllab_view=y',
    icon='fa-table',
    category='Sources',
    category_label=__('Sources'),
    category_icon='fa-table')

appbuilder.add_separator('Sources')
Exemple #43
0
class NNI_ModelView_Base():
    datamodel = SQLAInterface(NNI)
    conv = GeneralModelConverter(datamodel)
    label_title = 'nni超参搜索'
    check_redirect_list_url = '/nni_modelview/list/'
    help_url = conf.get('HELP_URL', {}).get(datamodel.obj.__tablename__,
                                            '') if datamodel else ''

    base_permissions = [
        'can_add', 'can_edit', 'can_delete', 'can_list', 'can_show'
    ]  # these are the defaults
    base_order = ('id', 'desc')
    base_filters = [["id", NNI_Filter, lambda: []]]  # 设置权限过滤器
    order_columns = ['id']
    list_columns = [
        'project', 'describe_url', 'job_type', 'creator', 'modified', 'run',
        'log'
    ]
    show_columns = [
        'created_by', 'changed_by', 'created_on', 'changed_on', 'job_type',
        'name', 'namespace', 'describe', 'parallel_trial_count',
        'max_trial_count', 'objective_type', 'objective_goal',
        'objective_metric_name', 'objective_additional_metric_names',
        'algorithm_name', 'algorithm_setting', 'parameters_html',
        'trial_spec_html', 'working_dir', 'volume_mount', 'node_selector',
        'image_pull_policy', 'resource_memory', 'resource_cpu', 'resource_gpu',
        'experiment_html', 'alert_status'
    ]

    add_form_query_rel_fields = {
        "project": [["name", Project_Join_Filter, 'org']]
    }
    edit_form_query_rel_fields = add_form_query_rel_fields
    edit_form_extra_fields = {}

    edit_form_extra_fields["alert_status"] = MySelectMultipleField(
        label=_(datamodel.obj.lab('alert_status')),
        widget=Select2ManyWidget(),
        # default=datamodel.obj.alert_status.default.arg,
        choices=[[x, x] for x in [
            'Pending', 'Running', 'Succeeded', 'Failed', 'Unknown', 'Waiting',
            'Terminated'
        ]],
        description="选择通知状态",
    )

    edit_form_extra_fields['name'] = StringField(
        _(datamodel.obj.lab('name')),
        description='英文名(字母、数字、- 组成),最长50个字符',
        widget=BS3TextFieldWidget(),
        validators=[
            DataRequired(),
            Regexp("^[a-z][a-z0-9\-]*[a-z0-9]$"),
            Length(1, 54)
        ])
    edit_form_extra_fields['describe'] = StringField(
        _(datamodel.obj.lab('describe')),
        description='中文描述',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()])
    edit_form_extra_fields['namespace'] = StringField(
        _(datamodel.obj.lab('namespace')),
        description='运行命名空间',
        widget=BS3TextFieldWidget(),
        default=datamodel.obj.namespace.default.arg,
        validators=[DataRequired()])

    edit_form_extra_fields['parallel_trial_count'] = IntegerField(
        _(datamodel.obj.lab('parallel_trial_count')),
        default=datamodel.obj.parallel_trial_count.default.arg,
        description='可并行的计算实例数目',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()])
    edit_form_extra_fields['max_trial_count'] = IntegerField(
        _(datamodel.obj.lab('max_trial_count')),
        default=datamodel.obj.max_trial_count.default.arg,
        description='最大并行的计算实例数目',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()])
    edit_form_extra_fields['max_failed_trial_count'] = IntegerField(
        _(datamodel.obj.lab('max_failed_trial_count')),
        default=datamodel.obj.max_failed_trial_count.default.arg,
        description='最大失败的计算实例数目',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()])
    edit_form_extra_fields['objective_type'] = SelectField(
        _(datamodel.obj.lab('objective_type')),
        default=datamodel.obj.objective_type.default.arg,
        description='目标函数类型(和自己代码中对应)',
        widget=Select2Widget(),
        choices=[['maximize', 'maximize'], ['minimize', 'minimize']],
        validators=[DataRequired()])

    edit_form_extra_fields['objective_goal'] = FloatField(
        _(datamodel.obj.lab('objective_goal')),
        default=datamodel.obj.objective_goal.default.arg,
        description='目标门限',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()])
    edit_form_extra_fields['objective_metric_name'] = StringField(
        _(datamodel.obj.lab('objective_metric_name')),
        default=NNI.objective_metric_name.default.arg,
        description='目标函数(和自己代码中对应)',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()])
    edit_form_extra_fields['objective_additional_metric_names'] = StringField(
        _(datamodel.obj.lab('objective_additional_metric_names')),
        default=datamodel.obj.objective_additional_metric_names.default.arg,
        description='其他目标函数(和自己代码中对应)',
        widget=BS3TextFieldWidget())

    algorithm_name_choices = {
        'TPE': "TPE",
        'Random': "随机搜索",
        "Anneal": "退火算法",
        "Evolution": "进化算法",
        "SMAC": "SMAC",
        "BatchTuner": "批量调参器",
        "GridSearch": "网格搜索",
        "Hyperband": "Hyperband",
        "NetworkMorphism": "Network Morphism",
        "MetisTuner": "Metis Tuner",
        "BOHB": "BOHB Advisor",
        "GPTuner": "GP Tuner",
        "PPOTuner": "PPO Tuner",
        "PBTTuner": "PBT Tuner"
    }

    algorithm_name_choices = list(algorithm_name_choices.items())

    edit_form_extra_fields['algorithm_name'] = SelectField(
        _(datamodel.obj.lab('algorithm_name')),
        default=datamodel.obj.algorithm_name.default.arg,
        description='搜索算法',
        widget=Select2Widget(),
        choices=algorithm_name_choices,
        validators=[DataRequired()])
    edit_form_extra_fields['algorithm_setting'] = StringField(
        _(datamodel.obj.lab('algorithm_setting')),
        default=datamodel.obj.algorithm_setting.default.arg,
        widget=BS3TextFieldWidget(),
        description='搜索算法配置')

    edit_form_extra_fields['parameters_demo'] = StringField(
        _(datamodel.obj.lab('parameters_demo')),
        description='搜索参数示例,标准json格式,注意:所有整型、浮点型都写成字符串型',
        widget=MyCodeArea(code=core.nni_parameters_demo()),
    )
    edit_form_extra_fields['parameters'] = StringField(
        _(datamodel.obj.lab('parameters')),
        default=datamodel.obj.parameters.default.arg,
        description='搜索参数,注意:所有整型、浮点型都写成字符串型',
        widget=MyBS3TextAreaFieldWidget(rows=10),
        validators=[DataRequired()])
    edit_form_extra_fields['node_selector'] = StringField(
        _(datamodel.obj.lab('node_selector')),
        description="部署task所在的机器(目前无需填写)",
        default=datamodel.obj.node_selector.default.arg,
        widget=BS3TextFieldWidget())
    edit_form_extra_fields['working_dir'] = StringField(
        _(datamodel.obj.lab('working_dir')),
        description="代码所在目录,nni代码、配置和log都将在/mnt/${your_name}/nni/目录下进行",
        default=datamodel.obj.working_dir.default.arg,
        widget=BS3TextFieldWidget())
    edit_form_extra_fields['image_pull_policy'] = SelectField(
        _(datamodel.obj.lab('image_pull_policy')),
        description="镜像拉取策略(always为总是拉取远程镜像,IfNotPresent为若本地存在则使用本地镜像)",
        widget=Select2Widget(),
        choices=[['Always', 'Always'], ['IfNotPresent', 'IfNotPresent']])
    edit_form_extra_fields['volume_mount'] = StringField(
        _(datamodel.obj.lab('volume_mount')),
        description=
        '外部挂载,格式:$pvc_name1(pvc):/$container_path1,$pvc_name2(pvc):/$container_path2',
        default=datamodel.obj.volume_mount.default.arg,
        widget=BS3TextFieldWidget())
    edit_form_extra_fields['resource_memory'] = StringField(
        _(datamodel.obj.lab('resource_memory')),
        default=datamodel.obj.resource_memory.default.arg,
        description='内存的资源使用限制(每个测试实例),示例:1G,20G',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()])
    edit_form_extra_fields['resource_cpu'] = StringField(
        _(datamodel.obj.lab('resource_cpu')),
        default=datamodel.obj.resource_cpu.default.arg,
        description='cpu的资源使用限制(每个测试实例)(单位:核),示例:2',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()])

    @pysnooper.snoop()
    def set_column(self, nni=None):
        # adjust the add/edit form fields
        request_data = request.args.to_dict()
        job_type = request_data.get('job_type', '')
        if nni:
            job_type = nni.job_type

        job_type_choices = ['', 'Job']
        job_type_choices = [[job_type_choice, job_type_choice]
                            for job_type_choice in job_type_choices]

        if nni:
            self.edit_form_extra_fields['job_type'] = SelectField(
                _(self.datamodel.obj.lab('job_type')),
                description="超参搜索的任务类型",
                choices=job_type_choices,
                widget=MySelect2Widget(extra_classes="readonly",
                                       value=job_type),
                validators=[DataRequired()])
        else:
            self.edit_form_extra_fields['job_type'] = SelectField(
                _(self.datamodel.obj.lab('job_type')),
                description="超参搜索的任务类型",
                widget=MySelect2Widget(new_web=True, value=job_type),
                choices=job_type_choices,
                validators=[DataRequired()])

        self.edit_form_extra_fields['tf_worker_num'] = IntegerField(
            _(self.datamodel.obj.lab('tf_worker_num')),
            default=json.loads(nni.job_json).get('tf_worker_num', 3)
            if nni and nni.job_json else 3,
            description='工作节点数目',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()])
        self.edit_form_extra_fields['tf_worker_image'] = StringField(
            _(self.datamodel.obj.lab('tf_worker_image')),
            default=json.loads(nni.job_json).get(
                'tf_worker_image', conf.get('NNI_TFJOB_DEFAULT_IMAGE', '')) if
            nni and nni.job_json else conf.get('NNI_TFJOB_DEFAULT_IMAGE', ''),
            description='工作节点镜像',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()])
        self.edit_form_extra_fields['tf_worker_command'] = StringField(
            _(self.datamodel.obj.lab('tf_worker_command')),
            default=json.loads(nni.job_json).get('tf_worker_command',
                                                 'python xx.py')
            if nni and nni.job_json else 'python xx.py',
            description='工作节点启动命令',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()])
        self.edit_form_extra_fields['job_worker_image'] = StringField(
            _(self.datamodel.obj.lab('job_worker_image')),
            default=json.loads(nni.job_json).get(
                'job_worker_image', conf.get('NNI_JOB_DEFAULT_IMAGE', ''))
            if nni and nni.job_json else conf.get('NNI_JOB_DEFAULT_IMAGE', ''),
            description='工作节点镜像',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()])
        self.edit_form_extra_fields['job_worker_command'] = StringField(
            _(self.datamodel.obj.lab('job_worker_command')),
            default=json.loads(nni.job_json).get('job_worker_command',
                                                 'python xx.py')
            if nni and nni.job_json else 'python xx.py',
            description='工作节点启动命令',
            widget=MyBS3TextAreaFieldWidget(),
            validators=[DataRequired()])
        self.edit_form_extra_fields['pytorch_worker_num'] = IntegerField(
            _(self.datamodel.obj.lab('pytorch_worker_num')),
            default=json.loads(nni.job_json).get('pytorch_worker_num', 3)
            if nni and nni.job_json else 3,
            description='工作节点数目',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()])
        self.edit_form_extra_fields['pytorch_worker_image'] = StringField(
            _(self.datamodel.obj.lab('pytorch_worker_image')),
            default=json.loads(nni.job_json).get(
                'pytorch_worker_image',
                conf.get('NNI_PYTORCHJOB_DEFAULT_IMAGE', '')) if nni
            and nni.job_json else conf.get('NNI_PYTORCHJOB_DEFAULT_IMAGE', ''),
            description='工作节点镜像',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()])
        self.edit_form_extra_fields['pytorch_master_command'] = StringField(
            _(self.datamodel.obj.lab('pytorch_master_command')),
            default=json.loads(nni.job_json).get('pytorch_master_command',
                                                 'python xx.py')
            if nni and nni.job_json else 'python xx.py',
            description='master节点启动命令',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()])
        self.edit_form_extra_fields['pytorch_worker_command'] = StringField(
            _(self.datamodel.obj.lab('pytorch_worker_command')),
            default=json.loads(nni.job_json).get('pytorch_worker_command',
                                                 'python xx.py')
            if nni and nni.job_json else 'python xx.py',
            description='工作节点启动命令',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()])

        self.edit_columns = [
            'job_type', 'project', 'name', 'namespace', 'describe',
            'parallel_trial_count', 'max_trial_count', 'objective_type',
            'objective_goal', 'objective_metric_name',
            'objective_additional_metric_names', 'algorithm_name',
            'algorithm_setting', 'parameters_demo', 'parameters'
        ]
        self.edit_fieldsets = [(
            lazy_gettext('common'),
            {
                "fields": copy.deepcopy(self.edit_columns),
                "expanded": True
            },
        )]

        if job_type == 'TFJob':
            group_columns = [
                'tf_worker_num', 'tf_worker_image', 'tf_worker_command'
            ]
            self.edit_fieldsets.append((
                lazy_gettext(job_type),
                {
                    "fields": group_columns,
                    "expanded": True
                },
            ))
            for column in group_columns:
                self.edit_columns.append(column)
        if job_type == 'Job':
            group_columns = ['job_worker_image', 'job_worker_command']
            self.edit_fieldsets.append((
                lazy_gettext(job_type),
                {
                    "fields": group_columns,
                    "expanded": True
                },
            ))
            for column in group_columns:
                self.edit_columns.append(column)
        if job_type == 'PyTorchJob':
            group_columns = [
                'pytorch_worker_num', 'pytorch_worker_image',
                'pytorch_master_command', 'pytorch_worker_command'
            ]
            self.edit_fieldsets.append((
                lazy_gettext(job_type),
                {
                    "fields": group_columns,
                    "expanded": True
                },
            ))
            for column in group_columns:
                self.edit_columns.append(column)

        if job_type == 'XGBoostJob':
            group_columns = [
                'pytorchjob_worker_image', 'pytorchjob_worker_command'
            ]
            self.edit_fieldsets.append((
                lazy_gettext(job_type),
                {
                    "fields": group_columns,
                    "expanded": True
                },
            ))
            for column in group_columns:
                self.edit_columns.append(column)

        task_column = ['working_dir', 'resource_memory', 'resource_cpu']
        self.edit_fieldsets.append((
            lazy_gettext('task args'),
            {
                "fields": task_column,
                "expanded": True
            },
        ))
        for column in task_column:
            self.edit_columns.append(column)

        self.edit_fieldsets.append((
            lazy_gettext('run experiment'),
            {
                "fields": ['alert_status'],
                "expanded": True
            },
        ))

        self.edit_columns.append('alert_status')

        self.add_form_extra_fields = self.edit_form_extra_fields
        self.add_fieldsets = self.edit_fieldsets
        self.add_columns = self.edit_columns

    pre_add_get = set_column
    pre_update_get = set_column

    # handle the form request
    def process_form(self, form, is_created):
        # from flask_appbuilder.forms import DynamicForm
        if 'parameters_demo' in form._fields:
            del form._fields['parameters_demo']  # this field is not processed

    # @pysnooper.snoop()
    def deploy_nni_service(self, nni, command):
        image_secrets = conf.get('HUBSECRET', [])
        user_hubsecrets = db.session.query(Repository.hubsecret).filter(
            Repository.created_by_fk == g.user.id).all()
        if user_hubsecrets:
            for hubsecret in user_hubsecrets:
                if hubsecret[0] not in image_secrets:
                    image_secrets.append(hubsecret[0])

        from myapp.utils.py.py_k8s import K8s
        k8s_client = K8s(nni.project.cluster.get('KUBECONFIG', ''))
        namespace = conf.get('KATIB_NAMESPACE')
        run_id = 'nni-' + nni.name

        try:
            nni_deploy = k8s_client.AppsV1Api.read_namespaced_deployment(
                name=nni.name, namespace=namespace)
            if nni_deploy:
                print('exist nni deploy')
                k8s_client.AppsV1Api.delete_namespaced_deployment(
                    name=nni.name, namespace=namespace)
                # return
        except Exception as e:
            print(e)

        volume_mount = nni.volume_mount + ",/usr/share/zoneinfo/Asia/Shanghai(hostpath):/etc/localtime"
        labels = {
            "nni": nni.name,
            "username": nni.created_by.username,
            'run-id': run_id
        }

        k8s_client.create_debug_pod(
            namespace=namespace,
            name=nni.name,
            labels=labels,
            command=command,
            args=None,
            volume_mount=volume_mount,
            working_dir='/mnt/%s' % nni.created_by.username,
            node_selector=nni.get_node_selector(),
            resource_memory='2G',
            resource_cpu='2',
            resource_gpu='0',
            image_pull_policy=conf.get('IMAGE_PULL_POLICY', 'Always'),
            image_pull_secrets=image_secrets,
            image=conf.get('NNI_IMAGES',
                           json.loads(nni.job_json).get('job_worker_image')),
            hostAliases=conf.get('HOSTALIASES', ''),
            env=None,
            privileged=False,
            accounts='nni',
            username=nni.created_by.username)

        k8s_client.create_service(namespace=namespace,
                                  name=nni.name,
                                  username=nni.created_by.username,
                                  ports=[8888],
                                  selector=labels)

        host = nni.project.cluster.get('NNI_DOMAIN', request.host)
        if not host:
            host = request.host
        if ':' in host:
            host = host[:host.rindex(':')]  # strip the port if one was captured
        vs_json = {
            "apiVersion": "networking.istio.io/v1alpha3",
            "kind": "VirtualService",
            "metadata": {
                "name": nni.name,
                "namespace": namespace
            },
            "spec": {
                "gateways": ["kubeflow/kubeflow-gateway"],
                "hosts": ["*" if core.checkip(host) else host],
                "http": [{
                    "match": [{
                        "uri": {
                            "prefix": "/nni/%s//" % nni.name
                        }
                    }, {
                        "uri": {
                            "prefix": "/nni/%s/" % nni.name
                        }
                    }],
                    "rewrite": {
                        "uri": "/nni/%s/" % nni.name
                    },
                    "route": [{
                        "destination": {
                            "host":
                            "%s.%s.svc.cluster.local" % (nni.name, namespace),
                            "port": {
                                "number": 8888
                            }
                        }
                    }],
                    "timeout":
                    "300s"
                }]
            }
        }
        crd_info = conf.get('CRD_INFO')['virtualservice']
        k8s_client.delete_istio_ingress(namespace=namespace, name=nni.name)

        k8s_client.create_crd(group=crd_info['group'],
                              version=crd_info['version'],
                              plural=crd_info['plural'],
                              namespace=namespace,
                              body=vs_json)
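
        # In short: deploy_nni_service() starts the NNI control pod with the given command,
        # exposes it on port 8888 through a Service, and publishes it behind the Istio
        # gateway under the /nni/<name>/ prefix via the VirtualService created above.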

    # create the experiment
    # @pysnooper.snoop()
    @expose('/run/<nni_id>', methods=['GET', 'POST'])
    # @pysnooper.snoop()
    def run(self, nni_id):
        nni = db.session.query(NNI).filter(NNI.id == nni_id).first()

        image_secrets = conf.get('HUBSECRET', [])
        user_hubsecrets = db.session.query(Repository.hubsecret).filter(
            Repository.created_by_fk == g.user.id).all()
        if user_hubsecrets:
            for hubsecret in user_hubsecrets:
                if hubsecret[0] not in image_secrets:
                    image_secrets.append(hubsecret[0])
        image_secrets = str(image_secrets)

        trial_template = f'''
apiVersion: frameworkcontroller.microsoft.com/v1
kind: Framework
metadata:
  name: {nni.name}
  namespace: {nni.namespace}
spec:
  executionType: Start
  retryPolicy:
    fancyRetryPolicy: true
    maxRetryCount: 2
  taskRoles:
  - name: worker
    taskNumber: 1
    frameworkAttemptCompletionPolicy:
      minFailedTaskCount: 1
      minSucceededTaskCount: 3
    task:
      retryPolicy:
        fancyRetryPolicy: false
        maxRetryCount: 0
      podGracefulDeletionTimeoutSec: 1800
      pod:
        spec:
          restartPolicy: Never
          hostNetwork: false
          imagePullSecrets: {image_secrets}

          containers:
          - name: {nni.name}
            image: {json.loads(nni.job_json).get("job_worker_image",'')}
            command: {json.loads(nni.job_json).get("job_worker_command",'').split(' ')}
            ports:
            - containerPort: 5001
            volumeMounts:
            - name: frameworkbarrier-volume
              mountPath: /mnt/frameworkbarrier
            - name: data-volume
              mountPath: /tmp/mount  
          serviceAccountName: frameworkbarrier
          initContainers:
          - name: frameworkbarrier
            image: frameworkcontroller/frameworkbarrier
            imagePullPolicy: IfNotPresent
            volumeMounts:
            - name: frameworkbarrier-volume
              mountPath: /mnt/frameworkbarrier
          volumes:
          - name: frameworkbarrier-volume
            emptyDir: {{}}
          - name: data-volume
            hostPath:
              path: {conf.get('WORKSPACE_HOST_PATH','')}/{nni.created_by.username}/nni/{nni.name}
        '''

        controll_yaml = f'''
authorName: default
experimentName: {nni.name}
trialConcurrency: {nni.parallel_trial_count}
maxExecDuration: {nni.maxExecDuration}s
maxTrialNum: {nni.max_trial_count}
logLevel: info
logCollection: none
#choice: local, remote, pai, kubeflow
trainingServicePlatform: frameworkcontroller
searchSpacePath: /mnt/{nni.created_by.username}/nni/{nni.name}/search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
  builtinTunerName: {nni.algorithm_name}
trial:
  codeDir: {nni.working_dir}  
frameworkcontrollerConfig:
  namespace: {conf.get('KATIB_NAMESPACE','katib')}
  storage: pvc
  configPath: /mnt/{nni.created_by.username}/nni/{nni.name}/trial_template.yaml  
  pvc: 
    path: "/mnt/{nni.created_by.username}/nni/{nni.name}/"
'''

        code_dir = "%s/%s/nni/%s" % (conf.get(
            'WORKSPACE_HOST_PATH', ''), nni.created_by.username, nni.name)
        if not os.path.exists(code_dir):
            os.makedirs(code_dir)
        trial_template_path = os.path.join(code_dir, 'trial_template.yaml')
        with open(trial_template_path, mode='w') as file:
            file.write(trial_template)

        controll_template_path = os.path.join(code_dir,
                                              'controll_template.yaml')
        with open(controll_template_path, mode='w') as file:
            file.write(controll_yaml)

        searchSpacePath = os.path.join(code_dir, 'search_space.json')
        with open(searchSpacePath, mode='w') as file:
            file.write(nni.parameters)

        flash('nni服务部署完成', category='success')

        # start-up command to execute
        # command = ['bash','-c','mkdir -p /nni/nni_node/static/nni/%s && cp -r /nni/nni_node/static/* /nni/nni_node/static/nni/%s/ ; nnictl create --config /mnt/%s/nni/%s/controll_template.yaml -p 8888 --foreground --url_prefix nni/%s'%(nni.name,nni.name,nni.created_by.username,nni.name,nni.name)]
        # command = ['bash', '-c','nnictl create --config /mnt/%s/nni/%s/controll_template.yaml -p 8888 --foreground' % (nni.created_by.username, nni.name)]
        command = [
            'bash', '-c',
            'nnictl create --config /mnt/%s/nni/%s/controll_template.yaml -p 8888 --foreground --url_prefix nni/%s'
            % (nni.created_by.username, nni.name, nni.name)
        ]

        print(command)
        self.deploy_nni_service(nni, command)

        return redirect('/nni_modelview/list/')

    # @pysnooper.snoop(watch_explode=())
    def merge_trial_spec(self, item):

        image_secrets = conf.get('HUBSECRET', [])
        user_hubsecrets = db.session.query(Repository.hubsecret).filter(
            Repository.created_by_fk == g.user.id).all()
        if user_hubsecrets:
            for hubsecret in user_hubsecrets:
                if hubsecret[0] not in image_secrets:
                    image_secrets.append(hubsecret[0])

        image_secrets = [{"name": hubsecret} for hubsecret in image_secrets]

        item.job_json = {}
        if item.job_type == 'TFJob':
            item.trial_spec = core.merge_tfjob_experiment_template(
                worker_num=item.tf_worker_num,
                node_selector=item.get_node_selector(),
                volume_mount=item.volume_mount,
                image=item.tf_worker_image,
                image_secrets=image_secrets,
                hostAliases=conf.get('HOSTALIASES', ''),
                workingDir=item.working_dir,
                image_pull_policy=conf.get('IMAGE_PULL_POLICY', 'Always'),
                resource_memory=item.resource_memory,
                resource_cpu=item.resource_cpu,
                command=item.tf_worker_command)
            item.job_json = {
                "tf_worker_num": item.tf_worker_num,
                "tf_worker_image": item.tf_worker_image,
                "tf_worker_command": item.tf_worker_command,
            }
        if item.job_type == 'Job':
            item.trial_spec = core.merge_job_experiment_template(
                node_selector=item.get_node_selector(),
                volume_mount=item.volume_mount,
                image=item.job_worker_image,
                image_secrets=image_secrets,
                hostAliases=conf.get('HOSTALIASES', ''),
                workingDir=item.working_dir,
                image_pull_policy=conf.get('IMAGE_PULL_POLICY', 'Always'),
                resource_memory=item.resource_memory,
                resource_cpu=item.resource_cpu,
                command=item.job_worker_command)

            item.job_json = {
                "job_worker_image": item.job_worker_image,
                "job_worker_command": item.job_worker_command,
            }
        if item.job_type == 'PyTorchJob':
            item.trial_spec = core.merge_pytorchjob_experiment_template(
                worker_num=item.pytorch_worker_num,
                node_selector=item.get_node_selector(),
                volume_mount=item.volume_mount,
                image=item.pytorch_worker_image,
                image_secrets=image_secrets,
                hostAliases=conf.get('HOSTALIASES', ''),
                workingDir=item.working_dir,
                image_pull_policy=conf.get('IMAGE_PULL_POLICY', 'Always'),
                resource_memory=item.resource_memory,
                resource_cpu=item.resource_cpu,
                master_command=item.pytorch_master_command,
                worker_command=item.pytorch_worker_command)

            item.job_json = {
                "pytorch_worker_num": item.pytorch_worker_num,
                "pytorch_worker_image": item.pytorch_worker_image,
                "pytorch_master_command": item.pytorch_master_command,
                "pytorch_worker_command": item.pytorch_worker_command,
            }
        item.job_json = json.dumps(item.job_json, indent=4, ensure_ascii=False)

    # check whether the parameters are valid
    # @pysnooper.snoop()
    def validate_parameters(self, parameters, algorithm):
        return parameters

    @expose("/log/<nni_id>", methods=["GET", "POST"])
    def log_task(self, nni_id):
        nni = db.session.query(NNI).filter_by(id=nni_id).first()
        from myapp.utils.py.py_k8s import K8s
        k8s = K8s(nni.project.cluster.get('KUBECONFIG', ''))
        namespace = conf.get('KATIB_NAMESPACE')
        pod = k8s.get_pods(namespace=namespace, pod_name=nni.name)
        if pod:
            pod = pod[0]
            return redirect("/myapp/web/log/%s/%s/%s" %
                            (nni.project.cluster['NAME'], namespace, nni.name))

        flash("未检测到当前搜索正在运行的容器", category='success')
        return redirect('/nni_modelview/list/')

    # @pysnooper.snoop()
    def pre_add(self, item):

        if item.job_type is None:
            raise MyappException("Job type is mandatory")

        if not item.volume_mount:
            item.volume_mount = item.project.volume_mount

        core.validate_json(item.parameters)
        item.parameters = self.validate_parameters(item.parameters,
                                                   item.algorithm_name)

        item.resource_memory = core.check_resource_memory(
            item.resource_memory,
            self.src_item_json.get('resource_memory', None)
            if self.src_item_json else None)
        item.resource_cpu = core.check_resource_cpu(
            item.resource_cpu,
            self.src_item_json.get('resource_cpu', None)
            if self.src_item_json else None)
        self.merge_trial_spec(item)
        # self.make_experiment(item)

    def pre_update(self, item):
        self.pre_add(item)

    @action("copy",
            __("Copy NNI Experiment"),
            confirmation=__('Copy NNI Experiment'),
            icon="fa-copy",
            multiple=True,
            single=False)
    def copy(self, nnis):
        if not isinstance(nnis, list):
            nnis = [nnis]
        for nni in nnis:
            new_nni = nni.clone()
            new_nni.name = new_nni.name + "-copy"
            new_nni.describe = new_nni.describe + "-copy"
            new_nni.created_on = datetime.datetime.now()
            new_nni.changed_on = datetime.datetime.now()
            db.session.add(new_nni)
            db.session.commit()

        return redirect(request.referrer)
Exemple #44
0
class AthenaEngineSpec(BaseEngineSpec):
    engine = "awsathena"
    engine_name = "Amazon Athena"
    allows_escaped_colons = False

    _time_grain_expressions = {
        None:
        "{col}",
        "PT1S":
        "date_trunc('second', CAST({col} AS TIMESTAMP))",
        "PT1M":
        "date_trunc('minute', CAST({col} AS TIMESTAMP))",
        "PT1H":
        "date_trunc('hour', CAST({col} AS TIMESTAMP))",
        "P1D":
        "date_trunc('day', CAST({col} AS TIMESTAMP))",
        "P1W":
        "date_trunc('week', CAST({col} AS TIMESTAMP))",
        "P1M":
        "date_trunc('month', CAST({col} AS TIMESTAMP))",
        "P3M":
        "date_trunc('quarter', CAST({col} AS TIMESTAMP))",
        "P1Y":
        "date_trunc('year', CAST({col} AS TIMESTAMP))",
        "P1W/1970-01-03T00:00:00Z":
        "date_add('day', 5, date_trunc('week', \
                                    date_add('day', 1, CAST({col} AS TIMESTAMP))))",
        "1969-12-28T00:00:00Z/P1W":
        "date_add('day', -1, date_trunc('week', \
                                    date_add('day', 1, CAST({col} AS TIMESTAMP))))",
    }
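
    # Illustrative only: Superset substitutes the temporal column into these templates, so
    # with col="order_ts" the "PT1M" grain above renders roughly as
    #   date_trunc('minute', CAST(order_ts AS TIMESTAMP))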

    custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[
        str, Any]]] = {
            SYNTAX_ERROR_REGEX: (
                __("Please check your query for syntax errors at or "
                   'near "%(syntax_error)s". Then, try running your query again.'
                   ),
                SupersetErrorType.SYNTAX_ERROR,
                {},
            ),
        }

    @classmethod
    def convert_dttm(
            cls,
            target_type: str,
            dttm: datetime,
            db_extra: Optional[Dict[str, Any]] = None) -> Optional[str]:
        tt = target_type.upper()
        if tt == utils.TemporalType.DATE:
            return f"from_iso8601_date('{dttm.date().isoformat()}')"
        if tt == utils.TemporalType.TIMESTAMP:
            datetime_formatted = dttm.isoformat(timespec="microseconds")
            return f"""from_iso8601_timestamp('{datetime_formatted}')"""
        return None
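
    # For example (illustrative values):
    #   convert_dttm("DATE", datetime(2021, 1, 1))      -> "from_iso8601_date('2021-01-01')"
    #   convert_dttm("TIMESTAMP", datetime(2021, 1, 1)) -> "from_iso8601_timestamp('2021-01-01T00:00:00.000000')"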

    @classmethod
    def epoch_to_dttm(cls) -> str:
        return "from_unixtime({col})"

    @staticmethod
    def _mutate_label(label: str) -> str:
        """
        Athena only supports lowercase column names and aliases.

        :param label: Expected expression label
        :return: Conditionally mutated label
        """
        return label.lower()
Exemple #45
0
class BigQueryEngineSpec(BaseEngineSpec):
    """Engine spec for Google's BigQuery

    As contributed by @mxmzdlv on issue #945"""

    engine = "bigquery"
    engine_name = "Google BigQuery"
    max_column_name_length = 128

    parameters_schema = BigQueryParametersSchema()
    default_driver = "bigquery"
    sqlalchemy_uri_placeholder = "bigquery://{project_id}"

    # BigQuery doesn't maintain context when running multiple statements in the
    # same cursor, so we need to run all statements at once
    run_multiple_statements_as_one = True

    """
    https://www.python.org/dev/peps/pep-0249/#arraysize
    raw_connections bypass the pybigquery query execution context and deal with
    raw dbapi connection directly.
    If this value is not set, the default value is set to 1, as described here,
    https://googlecloudplatform.github.io/google-cloud-python/latest/_modules/google/cloud/bigquery/dbapi/cursor.html#Cursor

    The default value of 5000 is derived from the pybigquery.
    https://github.com/mxmzdlv/pybigquery/blob/d214bb089ca0807ca9aaa6ce4d5a01172d40264e/pybigquery/sqlalchemy_bigquery.py#L102
    """
    arraysize = 5000

    _date_trunc_functions = {
        "DATE": "DATE_TRUNC",
        "DATETIME": "DATETIME_TRUNC",
        "TIME": "TIME_TRUNC",
        "TIMESTAMP": "TIMESTAMP_TRUNC",
    }

    _time_grain_expressions = {
        None: "{col}",
        "PT1S": "{func}({col}, SECOND)",
        "PT1M": "{func}({col}, MINUTE)",
        "PT5M": "CAST(TIMESTAMP_SECONDS("
        "5*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 5*60)"
        ") AS {type})",
        "PT10M": "CAST(TIMESTAMP_SECONDS("
        "10*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 10*60)"
        ") AS {type})",
        "PT15M": "CAST(TIMESTAMP_SECONDS("
        "15*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 15*60)"
        ") AS {type})",
        "PT0.5H": "CAST(TIMESTAMP_SECONDS("
        "30*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 30*60)"
        ") AS {type})",
        "PT1H": "{func}({col}, HOUR)",
        "P1D": "{func}({col}, DAY)",
        "P1W": "{func}({col}, WEEK)",
        "P1M": "{func}({col}, MONTH)",
        "P0.25Y": "{func}({col}, QUARTER)",
        "P1Y": "{func}({col}, YEAR)",
    }

    custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
        CONNECTION_DATABASE_PERMISSIONS_REGEX: (
            __(
                "We were unable to connect to your database. Please "
                "confirm that your service account has the Viewer "
                "and Job User roles on the project."
            ),
            SupersetErrorType.CONNECTION_DATABASE_PERMISSIONS_ERROR,
            {},
        ),
        TABLE_DOES_NOT_EXIST_REGEX: (
            __(
                'The table "%(table)s" does not exist. '
                "A valid table must be used to run this query.",
            ),
            SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR,
            {},
        ),
        COLUMN_DOES_NOT_EXIST_REGEX: (
            __('We can\'t seem to resolve column "%(column)s" at line %(location)s.'),
            SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR,
            {},
        ),
        SCHEMA_DOES_NOT_EXIST_REGEX: (
            __(
                'The schema "%(schema)s" does not exist. '
                "A valid schema must be used to run this query."
            ),
            SupersetErrorType.SCHEMA_DOES_NOT_EXIST_ERROR,
            {},
        ),
        SYNTAX_ERROR_REGEX: (
            __(
                "Please check your query for syntax errors at or near "
                '"%(syntax_error)s". Then, try running your query again.'
            ),
            SupersetErrorType.SYNTAX_ERROR,
            {},
        ),
    }
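
    # Roughly how these entries are used: when a raw driver error matches one of the regexes,
    # its named groups (e.g. "table", "column", "syntax_error") are substituted into the
    # translated message template and the error is surfaced with the given SupersetErrorType.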

    @classmethod
    def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
        tt = target_type.upper()
        if tt == utils.TemporalType.DATE:
            return f"CAST('{dttm.date().isoformat()}' AS DATE)"
        if tt == utils.TemporalType.DATETIME:
            return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)"""
        if tt == utils.TemporalType.TIME:
            return f"""CAST('{dttm.strftime("%H:%M:%S.%f")}' AS TIME)"""
        if tt == utils.TemporalType.TIMESTAMP:
            return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)"""
        return None
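
    # For example (illustrative values):
    #   convert_dttm("DATE", datetime(2021, 1, 1))     -> "CAST('2021-01-01' AS DATE)"
    #   convert_dttm("DATETIME", datetime(2021, 1, 1)) -> "CAST('2021-01-01T00:00:00.000000' AS DATETIME)"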

    @classmethod
    def fetch_data(
        cls, cursor: Any, limit: Optional[int] = None
    ) -> List[Tuple[Any, ...]]:
        data = super().fetch_data(cursor, limit)
        # Support type BigQuery Row, introduced here PR #4071
        # google.cloud.bigquery.table.Row
        if data and type(data[0]).__name__ == "Row":
            data = [r.values() for r in data]  # type: ignore
        return data

    @staticmethod
    def _mutate_label(label: str) -> str:
        """
        BigQuery field_name should start with a letter or underscore and contain only
        alphanumeric characters. Labels that start with a number are prefixed with an
        underscore. Any unsupported characters are replaced with underscores and an
        md5 hash is added to the end of the label to avoid possible collisions.

        :param label: Expected expression label
        :return: Conditionally mutated label
        """
        label_hashed = "_" + md5_sha_from_str(label)

        # if label starts with number, add underscore as first character
        label_mutated = "_" + label if re.match(r"^\d", label) else label

        # replace non-alphanumeric characters with underscores
        label_mutated = re.sub(r"[^\w]+", "_", label_mutated)
        if label_mutated != label:
            # add first 5 chars from md5 hash to label to avoid possible collisions
            label_mutated += label_hashed[:6]

        return label_mutated
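
    # Illustrative behaviour (the hash characters depend on md5_sha_from_str):
    #   "mycol"     -> "mycol"                       (already valid, returned unchanged)
    #   "count(*)"  -> "count__" + 5 hash chars      (the "(*)" run becomes "_", hash suffix appended)
    #   "1st_place" -> "_1st_place_" + 5 hash chars  (leading digit gets an underscore prefix)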

    @classmethod
    def _truncate_label(cls, label: str) -> str:
        """BigQuery requires column names start with either a letter or
        underscore. To make sure this is always the case, an underscore is prefixed
        to the md5 hash of the original label.

        :param label: expected expression label
        :return: truncated label
        """
        return "_" + md5_sha_from_str(label)

    @classmethod
    def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Normalizes indexes for more consistency across db engines

        :param indexes: Raw indexes as returned by SQLAlchemy
        :return: cleaner, more aligned index definition
        """
        normalized_idxs = []
        # Fixing a bug/behavior observed in pybigquery==0.4.15 where
        # the index's `column_names` == [None]
        # Here we're returning only non-None indexes
        for ix in indexes:
            column_names = ix.get("column_names") or []
            ix["column_names"] = [col for col in column_names if col is not None]
            if ix["column_names"]:
                normalized_idxs.append(ix)
        return normalized_idxs
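
    # Illustrative example: {"name": "partition", "column_names": [None]} is dropped entirely,
    # while {"name": "clustering", "column_names": ["a", None]} is kept and normalized to
    # {"name": "clustering", "column_names": ["a"]}.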

    @classmethod
    def extra_table_metadata(
        cls, database: "Database", table_name: str, schema_name: str
    ) -> Dict[str, Any]:
        indexes = database.get_indexes(table_name, schema_name)
        if not indexes:
            return {}
        partitions_columns = [
            index.get("column_names", [])
            for index in indexes
            if index.get("name") == "partition"
        ]
        cluster_columns = [
            index.get("column_names", [])
            for index in indexes
            if index.get("name") == "clustering"
        ]
        return {
            "partitions": {"cols": partitions_columns},
            "clustering": {"cols": cluster_columns},
        }

    @classmethod
    def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
        """
        BigQuery dialect requires us to not use backtick in the fieldname which are
        nested.
        Using literal_column handles that issue.
        https://docs.sqlalchemy.org/en/latest/core/tutorial.html#using-more-specific-text-with-table-literal-column-and-column
        Also explicility specifying column names so we don't encounter duplicate
        column names in the result.
        """
        return [
            literal_column(c["name"]).label(c["name"].replace(".", "__")) for c in cols
        ]
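
    # Illustrative example: a nested column {"name": "user.name"} becomes
    #   literal_column("user.name").label("user__name")
    # i.e. SELECT user.name AS user__name, with no backticks around the dotted path.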

    @classmethod
    def epoch_to_dttm(cls) -> str:
        return "TIMESTAMP_SECONDS({col})"

    @classmethod
    def epoch_ms_to_dttm(cls) -> str:
        return "TIMESTAMP_MILLIS({col})"

    @classmethod
    def df_to_sql(
        cls,
        database: "Database",
        table: Table,
        df: pd.DataFrame,
        to_sql_kwargs: Dict[str, Any],
    ) -> None:
        """
        Upload data from a Pandas DataFrame to a database.

        Calls `pandas_gbq.DataFrame.to_gbq` which requires `pandas_gbq` to be installed.

        Note this method does not create metadata for the table.

        :param database: The database to upload the data to
        :param table: The table to upload the data to
        :param df: The dataframe with data to be uploaded
        :param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
        """

        try:
            import pandas_gbq
            from google.oauth2 import service_account
        except ImportError:
            raise Exception(
                "Could not import libraries `pandas_gbq` or `google.oauth2`, which are "
                "required to be installed in your environment in order "
                "to upload data to BigQuery"
            )

        if not table.schema:
            raise Exception("The table schema must be defined")

        engine = cls.get_engine(database)
        to_gbq_kwargs = {"destination_table": str(table), "project_id": engine.url.host}

        # Add credentials if they are set on the SQLAlchemy dialect.
        creds = engine.dialect.credentials_info

        if creds:
            to_gbq_kwargs[
                "credentials"
            ] = service_account.Credentials.from_service_account_info(creds)

        # Only pass through supported kwargs.
        supported_kwarg_keys = {"if_exists"}

        for key in supported_kwarg_keys:
            if key in to_sql_kwargs:
                to_gbq_kwargs[key] = to_sql_kwargs[key]

        pandas_gbq.to_gbq(df, **to_gbq_kwargs)
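
    # A minimal usage sketch (argument values are illustrative assumptions):
    #   BigQueryEngineSpec.df_to_sql(
    #       database=database,                       # a Database object configured for BigQuery
    #       table=table,                             # must carry a schema (see the check above)
    #       df=df,
    #       to_sql_kwargs={"if_exists": "replace"},  # only "if_exists" is passed through
    #   )
    # which ultimately calls pandas_gbq.to_gbq(df, destination_table=str(table), project_id=..., ...).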

    @classmethod
    def build_sqlalchemy_uri(
        cls, _: BigQueryParametersType, encrypted_extra: Optional[Dict[str, Any]] = None
    ) -> str:
        if encrypted_extra:
            project_id = encrypted_extra.get("credentials_info", {}).get("project_id")

            if project_id:
                return f"{cls.engine}+{cls.default_driver}://{project_id}"

        raise SupersetGenericDBErrorException(
            message="Big Query encrypted_extra is not available.",
        )

    @classmethod
    def get_parameters_from_uri(
        cls, _: str, encrypted_extra: Optional[Dict[str, str]] = None
    ) -> Any:
        # BigQuery doesn't have parameters
        if encrypted_extra:
            return encrypted_extra

        raise SupersetGenericDBErrorException(
            message="Big Query encrypted_extra is not available.",
        )

    @classmethod
    def validate_parameters(
        cls, parameters: BigQueryParametersType  # pylint: disable=unused-argument
    ) -> List[SupersetError]:
        return []

    @classmethod
    def parameters_json_schema(cls) -> Any:
        """
        Return configuration parameters as OpenAPI.
        """
        if not cls.parameters_schema:
            return None

        spec = APISpec(
            title="Database Parameters",
            version="1.0.0",
            openapi_version="3.0.0",
            plugins=[ma_plugin],
        )

        ma_plugin.init_spec(spec)
        ma_plugin.converter.add_attribute_function(encrypted_field_properties)
        spec.components.schema(cls.__name__, schema=cls.parameters_schema)
        return spec.to_dict()["components"]["schemas"][cls.__name__]
class QueryView(SupersetModelView):
    datamodel = SQLAInterface(Query)
    list_columns = ['user', 'database', 'status', 'start_time', 'end_time']
    label_columns = {
        'user': _('User'),
        'database': _('Database'),
        'status': _('Status'),
        'start_time': _('Start Time'),
        'end_time': _('End Time'),
    }

appbuilder.add_view(
    QueryView,
    "Queries",
    label=__("Queries"),
    category="Manage",
    category_label=__("Manage"),
    icon="fa-search")


class SavedQueryView(SupersetModelView, DeleteMixin):
    datamodel = SQLAInterface(SavedQuery)

    list_title = _('List Saved Query')
    show_title = _('Show Saved Query')
    add_title = _('Add Saved Query')
    edit_title = _('Edit Saved Query')

    list_columns = [
        'label', 'user', 'database', 'schema', 'description',
Example #47
0
class DeleteMixin(object):
    def _delete(self, pk):
        """
            Delete function logic; override to implement different behavior.
            Deletes the record with primary_key == pk.

            :param pk:
                record primary key to delete
        """
        item = self.datamodel.get(pk, self._base_filters)
        if not item:
            abort(404)
        try:
            self.pre_delete(item)
        except Exception as e:
            flash(str(e), 'danger')
        else:
            view_menu = security_manager.find_view_menu(item.get_perm())
            pvs = security_manager.get_session.query(
                security_manager.permissionview_model).filter_by(
                view_menu=view_menu).all()

            schema_view_menu = None
            if hasattr(item, 'schema_perm'):
                schema_view_menu = security_manager.find_view_menu(item.schema_perm)

                pvs.extend(security_manager.get_session.query(
                    security_manager.permissionview_model).filter_by(
                    view_menu=schema_view_menu).all())

            if self.datamodel.delete(item):
                self.post_delete(item)

                for pv in pvs:
                    security_manager.get_session.delete(pv)

                if view_menu:
                    security_manager.get_session.delete(view_menu)

                if schema_view_menu:
                    security_manager.get_session.delete(schema_view_menu)

                security_manager.get_session.commit()

            flash(*self.datamodel.message)
            self.update_redirect()

    @action(
        'muldelete',
        __('Delete'),
        __('Delete all Really?'),
        'fa-trash',
        single=False,
    )
    def muldelete(self, items):
        if not items:
            abort(404)
        for item in items:
            try:
                self.pre_delete(item)
            except Exception as e:
                flash(str(e), 'danger')
            else:
                self._delete(item.id)
        self.update_redirect()
        return redirect(self.get_redirect())
def get_datasource_exist_error_mgs(full_name):
    return __("Datasource %(name)s already exists", name=full_name)
Example #49
0
def execute_sql_statement(  # pylint: disable=too-many-arguments,too-many-locals
    sql_statement: str,
    query: Query,
    user_name: Optional[str],
    session: Session,
    cursor: Any,
    log_params: Optional[Dict[str, Any]],
    apply_ctas: bool = False,
) -> SupersetResultSet:
    """Executes a single SQL statement"""
    database = query.database
    db_engine_spec = database.db_engine_spec
    parsed_query = ParsedQuery(sql_statement)
    sql = parsed_query.stripped()
    # This is a test to see if the query is being
    # limited by either the dropdown or the sql.
    # We are testing to see if more rows exist than the limit.
    increased_limit = None if query.limit is None else query.limit + 1
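    # e.g. with a limit of 1000 we request 1001 rows; if all 1001 come back we
    # know the result was truncated by the limit and drop the extra row below.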

    if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
        raise SupersetErrorException(
            SupersetError(
                message=__("Only SELECT statements are allowed against this database."),
                error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
                level=ErrorLevel.ERROR,
            )
        )
    if apply_ctas:
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = "tmp_{}_table_{}".format(
                query.user_id, start_dttm.strftime("%Y_%m_%d_%H_%M_%S")
            )
        sql = parsed_query.as_create_table(
            query.tmp_table_name,
            schema_name=query.tmp_schema_name,
            method=query.ctas_method,
        )
        query.select_as_cta_used = True

    # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
    if db_engine_spec.is_select_query(parsed_query) and not (
        query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
    ):
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        sql = apply_limit_if_exists(database, increased_limit, query, sql)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
    try:
        query.executed_sql = sql
        if log_query:
            log_query(
                query.database.sqlalchemy_uri,
                query.executed_sql,
                query.schema,
                user_name,
                __name__,
                security_manager,
                log_params,
            )
        session.commit()
        with stats_timing("sqllab.query.time_executing_query", stats_logger):
            logger.debug("Query %d: Running query: %s", query.id, sql)
            db_engine_spec.execute(cursor, sql, async_=True)
            logger.debug("Query %d: Handling cursor", query.id)
            db_engine_spec.handle_cursor(cursor, query, session)

        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
            logger.debug(
                "Query %d: Fetching data for query object: %s",
                query.id,
                str(query.to_dict()),
            )
            data = db_engine_spec.fetch_data(cursor, increased_limit)
            if query.limit is None or len(data) <= query.limit:
                query.limiting_factor = LimitingFactor.NOT_LIMITED
            else:
                # return 1 row less than increased_query
                data = data[:-1]
    except SoftTimeLimitExceeded as ex:
        logger.warning("Query %d: Time limit exceeded", query.id)
        logger.debug("Query %d: %s", query.id, ex)
        raise SupersetErrorException(
            SupersetError(
                message=__(
                    "The query was killed after %(sqllab_timeout)s seconds. It might "
                    "be too complex, or the database might be under heavy load.",
                    sqllab_timeout=SQLLAB_TIMEOUT,
                ),
                error_type=SupersetErrorType.SQLLAB_TIMEOUT_ERROR,
                level=ErrorLevel.ERROR,
            )
        ) from ex
    except Exception as ex:
        # query is stopped in another thread/worker
        # stopping raises expected exceptions which we should skip
        session.refresh(query)
        if query.status == QueryStatus.STOPPED:
            raise SqlLabQueryStoppedException() from ex

        logger.error("Query %d: %s", query.id, type(ex), exc_info=True)
        logger.debug("Query %d: %s", query.id, ex)
        raise SqlLabException(db_engine_spec.extract_error_message(ex)) from ex

    logger.debug("Query %d: Fetching cursor description", query.id)
    cursor_description = cursor.description
    return SupersetResultSet(data, cursor_description, db_engine_spec)
Example #50
0
def get_datasource_exist_error_msg(full_name):
    return __('Datasource %(name)s already exists', name=full_name)

    def pre_add(self, obj):
        if not obj.end_dttm:
            obj.end_dttm = obj.start_dttm
        elif obj.end_dttm < obj.start_dttm:
            raise Exception(
                "Annotation end time must be no earlier than start time.")

    def pre_update(self, obj):
        self.pre_add(obj)


class AnnotationLayerModelView(SupersetModelView, DeleteMixin):
    datamodel = SQLAInterface(AnnotationLayer)
    list_columns = ['id', 'name']
    edit_columns = ['name', 'descr']
    add_columns = edit_columns


appbuilder.add_view(AnnotationLayerModelView,
                    "Annotation Layers",
                    label=__("Annotation Layers"),
                    icon="fa-comment",
                    category="Manage",
                    category_label=__("Manage"),
                    category_icon='')
appbuilder.add_view(AnnotationModelView,
                    "Annotations",
                    label=__("Annotations"),
                    icon="fa-comments",
                    category="Manage",
                    category_label=__("Manage"),
                    category_icon='')
Example #52
0
def deliver_dashboard(
    dashboard_id: int,
    recipients: Optional[str],
    slack_channel: Optional[str],
    delivery_type: EmailDeliveryType,
    deliver_as_group: bool,
) -> None:
    """
    Given a schedule, deliver the dashboard as an email report
    """
    dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one()

    dashboard_url = _get_url_path("Superset.dashboard",
                                  dashboard_id_or_slug=dashboard.id)
    dashboard_url_user_friendly = _get_url_path(
        "Superset.dashboard",
        user_friendly=True,
        dashboard_id_or_slug=dashboard.id)

    # Create a driver, fetch the page, wait for the page to render
    driver = create_webdriver()
    window = config["WEBDRIVER_WINDOW"]["dashboard"]
    driver.set_window_size(*window)
    driver.get(dashboard_url)
    time.sleep(EMAIL_PAGE_RENDER_WAIT)

    # Set up a function to retry once for the element.
    # This is buggy in certain selenium versions with firefox driver
    get_element = getattr(driver, "find_element_by_class_name")
    element = retry_call(get_element,
                         fargs=["grid-container"],
                         tries=2,
                         delay=EMAIL_PAGE_RENDER_WAIT)

    try:
        screenshot = element.screenshot_as_png
    except WebDriverException:
        # Some webdrivers do not support screenshots for elements.
        # In such cases, take a screenshot of the entire page.
        screenshot = driver.screenshot()  # pylint: disable=no-member
    finally:
        destroy_webdriver(driver)

    # Generate the email body and attachments
    report_content = _generate_report_content(
        delivery_type,
        screenshot,
        dashboard.dashboard_title,
        dashboard_url_user_friendly,
    )

    subject = __(
        "%(prefix)s %(title)s",
        prefix=config["EMAIL_REPORTS_SUBJECT_PREFIX"],
        title=dashboard.dashboard_title,
    )

    if recipients:
        _deliver_email(
            recipients,
            deliver_as_group,
            subject,
            report_content.body,
            report_content.data,
            report_content.images,
        )
    if slack_channel:
        deliver_slack_msg(
            slack_channel,
            subject,
            report_content.slack_message,
            report_content.slack_attachment,
        )

    def pre_add(self, cluster):
        security_manager.merge_perm('database_access', cluster.perm)

    def pre_update(self, cluster):
        self.pre_add(cluster)

    def _delete(self, pk):
        DeleteMixin._delete(self, pk)


appbuilder.add_view(
    DruidClusterModelView,
    name='Druid Clusters',
    label=__('Druid Clusters'),
    icon='fa-cubes',
    category='Sources',
    category_label=__('Sources'),
    category_icon='fa-database',
)


class DruidDatasourceModelView(DatasourceModelView, DeleteMixin,
                               YamlExportMixin):  # noqa
    datamodel = SQLAInterface(models.DruidDatasource)

    list_title = _('List Druid Datasource')
    show_title = _('Show Druid Datasource')
    add_title = _('Add Druid Datasource')
    edit_title = _('Edit Druid Datasource')
Example #54
0
            security.merge_perm(sm, 'schema_access', table.schema_perm)

        if flash_message:
            flash(_(
                "The table was created. "
                "As part of this two phase configuration "
                "process, you should now click the edit button by "
                "the new table to configure it."), "info")

    def post_update(self, table):
        self.post_add(table, flash_message=False)

    @expose('/edit/<pk>', methods=['GET', 'POST'])
    @has_access
    def edit(self, pk):
        """Simple hack to redirect to explore view after saving"""
        resp = super(TableModelView, self).edit(pk)
        if isinstance(resp, basestring):
            return resp
        return redirect('/superset/explore/table/{}/'.format(pk))

appbuilder.add_view(
    TableModelView,
    "Tables",
    label=__("Tables"),
    category="Sources",
    category_label=__("Sources"),
    icon='fa-table',)

appbuilder.add_separator("Sources")
Example #55
0
def _get_slice_data(slc: Slice,
                    delivery_type: EmailDeliveryType) -> ReportContent:
    slice_url = _get_url_path("Superset.explore_json",
                              csv="true",
                              form_data=json.dumps({"slice_id": slc.id}))

    # URL to include in the email
    slice_url_user_friendly = _get_url_path("Superset.slice",
                                            slice_id=slc.id,
                                            user_friendly=True)

    # Login on behalf of the "reports" user in order to get cookies to deal with auth
    auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(
        get_reports_user())
    # Build something like "session=cool_sess.val;other-cookie=awesome_other_cookie"
    cookie_str = ";".join(
        [f"{key}={val}" for key, val in auth_cookies.items()])

    opener = urllib.request.build_opener()
    opener.addheaders.append(("Cookie", cookie_str))
    response = opener.open(slice_url)
    if response.getcode() != 200:
        raise URLError(response.getcode())

    # TODO: Move to the csv module
    content = response.read()
    rows = [r.split(b",") for r in content.splitlines()]

    if delivery_type == EmailDeliveryType.inline:
        data = None

        # Parse the csv file and generate HTML
        columns = rows.pop(0)
        with app.app_context():  # type: ignore
            body = render_template(
                "superset/reports/slice_data.html",
                columns=columns,
                rows=rows,
                name=slc.slice_name,
                link=slice_url_user_friendly,
            )

    elif delivery_type == EmailDeliveryType.attachment:
        data = {__("%(name)s.csv", name=slc.slice_name): content}
        body = __(
            '<b><a href="%(url)s">Explore in Superset</a></b><p></p>',
            name=slc.slice_name,
            url=slice_url_user_friendly,
        )

    # how to: https://api.slack.com/reference/surfaces/formatting
    slack_message = __(
        """
        *%(slice_name)s*\n
        <%(slice_url_user_friendly)s|Explore in Superset>
        """,
        slice_name=slc.slice_name,
        slice_url_user_friendly=slice_url_user_friendly,
    )

    return ReportContent(body, data, None, slack_message, content)
Example #56
0
        successes = []
        failures = []
        for t in tables:
            try:
                t.fetch_metadata()
                successes.append(t)
            except Exception:
                failures.append(t)

        if len(successes) > 0:
            success_msg = _(
                'Metadata refreshed for the following table(s): %(tables)s',
                tables=', '.join([t.table_name for t in successes]))
            flash(success_msg, 'info')
        if len(failures) > 0:
            failure_msg = _(
                'Unable to retrieve metadata for the following table(s): %(tables)s',
                tables=', '.join([t.table_name for t in failures]))
            flash(failure_msg, 'danger')

        return redirect('/tablemodelview/list/')


appbuilder.add_view_no_menu(TableModelView)
appbuilder.add_link(
    'Tables',
    label=__('Tables'),
    href='/tablemodelview/list/?_flt_1_is_sqllab_view=y',
    icon='fa-table',
    category='Sources',
    category_label=__('Sources'),
    category_icon='fa-table')

appbuilder.add_separator('Sources')
Example #57
0
class DashboardModelView(DashboardMixin, SupersetModelView, DeleteMixin):  # pylint: disable=too-many-ancestors
    route_base = "/dashboard"
    datamodel = SQLAInterface(models.Dashboard)
    # TODO disable api_read and api_delete (used by cypress)
    # once we move to ChartRestModelApi
    include_route_methods = RouteMethod.CRUD_SET | {
        RouteMethod.API_READ,
        RouteMethod.API_DELETE,
        "download_dashboards",
    }

    @has_access
    @expose("/list/")
    def list(self):
        if not app.config["ENABLE_REACT_CRUD_VIEWS"]:
            return super().list()
        payload = {
            "user": bootstrap_user_data(g.user),
            "common": common_bootstrap_payload(),
        }
        return self.render_template(
            "superset/welcome.html",
            entry="welcome",
            bootstrap_data=json.dumps(
                payload, default=utils.pessimistic_json_iso_dttm_ser),
        )

    @action("mulexport", __("Export"), __("Export dashboards?"), "fa-database")
    def mulexport(self, items):  # pylint: disable=no-self-use
        if not isinstance(items, list):
            items = [items]
        ids = "".join("&id={}".format(d.id) for d in items)
        return redirect("/dashboard/export_dashboards_form?{}".format(ids[1:]))

    @event_logger.log_this
    @has_access
    @expose("/export_dashboards_form")
    def download_dashboards(self):
        if request.args.get("action") == "go":
            ids = request.args.getlist("id")
            return Response(
                models.Dashboard.export_dashboards(ids),
                headers=generate_download_headers("json"),
                mimetype="application/text",
            )
        return self.render_template("superset/export_dashboards.html",
                                    dashboards_url="/dashboard/list")

    def pre_add(self, item):
        item.slug = item.slug or None
        if item.slug:
            item.slug = item.slug.strip()
            item.slug = item.slug.replace(" ", "-")
            item.slug = re.sub(r"[^\w\-]+", "", item.slug)
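            # e.g. a slug of "Sales Overview 2021!" (hypothetical) becomes
            # "Sales-Overview-2021" after the strip/replace/sub above.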
        if g.user not in item.owners:
            item.owners.append(g.user)
        utils.validate_json(item.json_metadata)
        utils.validate_json(item.position_json)
        owners = [o for o in item.owners]
        for slc in item.slices:
            slc.owners = list(set(owners) | set(slc.owners))

    def pre_update(self, item):
        check_ownership(item)
        self.pre_add(item)
Example #58
0
    def explore(self, datasource_type, datasource_id):

        error_redirect = '/slicemodelview/list/'
        datasource_class = models.SqlaTable \
            if datasource_type == "table" else models.DruidDatasource
        datasources = (
            db.session
            .query(datasource_class)
            .all()
        )
        datasources = sorted(datasources, key=lambda ds: ds.full_name)
        datasource = [ds for ds in datasources if int(datasource_id) == ds.id]
        datasource = datasource[0] if datasource else None
        slice_id = request.args.get("slice_id")
        slc = None

        if slice_id:
            slc = (
                db.session.query(models.Slice)
                .filter_by(id=slice_id)
                .first()
            )
        if not datasource:
            flash(__("The datasource seems to have been deleted"), "alert")
            return redirect(error_redirect)

        slice_add_perm = self.can_access('can_add', 'SliceModelView')
        slice_edit_perm = check_ownership(slc, raise_if_false=False)
        slice_download_perm = self.can_access('can_download', 'SliceModelView')

        all_datasource_access = self.can_access(
            'all_datasource_access', 'all_datasource_access')
        datasource_access = self.can_access(
            'datasource_access', datasource.perm)
        if not (all_datasource_access or datasource_access):
            flash(__("You don't seem to have access to this datasource"), "danger")
            return redirect(error_redirect)

        action = request.args.get('action')
        if action in ('saveas', 'overwrite'):
            return self.save_or_overwrite_slice(
                request.args, slc, slice_add_perm, slice_edit_perm)

        viz_type = request.args.get("viz_type")
        if not viz_type and datasource.default_endpoint:
            return redirect(datasource.default_endpoint)
        if not viz_type:
            viz_type = "table"
        try:
            obj = viz.viz_types[viz_type](
                datasource,
                form_data=request.args,
                slice_=slc)
        except Exception as e:
            flash(str(e), "danger")
            return redirect(error_redirect)
        if request.args.get("json") == "true":
            status = 200
            if config.get("DEBUG"):
                # Allows for nice debugger stack traces in debug mode
                payload = obj.get_json()
            else:
                try:
                    payload = obj.get_json()
                except Exception as e:
                    logging.exception(e)
                    payload = str(e)
                    status = 500
            resp = Response(
                payload,
                status=status,
                mimetype="application/json")
            return resp
        elif request.args.get("csv") == "true":
            payload = obj.get_csv()
            return Response(
                payload,
                status=200,
                headers=generate_download_headers("csv"),
                mimetype="application/csv")
        else:
            if request.args.get("standalone") == "true":
                template = "caravel/standalone.html"
            else:
                template = "caravel/explore.html"
            try:
                resp = self.render_template(
                    template, viz=obj, slice=slc, datasources=datasources,
                    can_add=slice_add_perm, can_edit=slice_edit_perm,
                    can_download=slice_download_perm,
                    userid=g.user.get_id() if g.user else '')
            except Exception as e:
                if config.get("DEBUG"):
                    raise e
                return Response(
                    str(e),
                    status=500,
                    mimetype="application/json")
            return resp
Example #59
0
class BigQueryEngineSpec(BaseEngineSpec):
    """Engine spec for Google's BigQuery

    As contributed by @mxmzdlv on issue #945"""

    engine = "bigquery"
    engine_name = "Google BigQuery"
    max_column_name_length = 128

    # BigQuery doesn't maintain context when running multiple statements in the
    # same cursor, so we need to run all statements at once
    run_multiple_statements_as_one = True
    """
    https://www.python.org/dev/peps/pep-0249/#arraysize
    raw_connections bypass the pybigquery query execution context and deal with
    raw dbapi connection directly.
    If this value is not set, the default value is set to 1, as described here,
    https://googlecloudplatform.github.io/google-cloud-python/latest/_modules/google/cloud/bigquery/dbapi/cursor.html#Cursor

    The default value of 5000 is derived from the pybigquery.
    https://github.com/mxmzdlv/pybigquery/blob/d214bb089ca0807ca9aaa6ce4d5a01172d40264e/pybigquery/sqlalchemy_bigquery.py#L102
    """
    arraysize = 5000

    _date_trunc_functions = {
        "DATE": "DATE_TRUNC",
        "DATETIME": "DATETIME_TRUNC",
        "TIME": "TIME_TRUNC",
        "TIMESTAMP": "TIMESTAMP_TRUNC",
    }

    _time_grain_expressions = {
        None:
        "{col}",
        "PT1S":
        "{func}({col}, SECOND)",
        "PT1M":
        "{func}({col}, MINUTE)",
        "PT5M":
        "CAST(TIMESTAMP_SECONDS("
        "5*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 5*60)"
        ") AS {type})",
        "PT10M":
        "CAST(TIMESTAMP_SECONDS("
        "10*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 10*60)"
        ") AS {type})",
        "PT15M":
        "CAST(TIMESTAMP_SECONDS("
        "15*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 15*60)"
        ") AS {type})",
        "PT0.5H":
        "CAST(TIMESTAMP_SECONDS("
        "30*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 30*60)"
        ") AS {type})",
        "PT1H":
        "{func}({col}, HOUR)",
        "P1D":
        "{func}({col}, DAY)",
        "P1W":
        "{func}({col}, WEEK)",
        "P1M":
        "{func}({col}, MONTH)",
        "P0.25Y":
        "{func}({col}, QUARTER)",
        "P1Y":
        "{func}({col}, YEAR)",
    }
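    # Illustration (hypothetical column name): formatting the "PT5M" template with
    # col="created_at" and type="TIMESTAMP" yields
    #   CAST(TIMESTAMP_SECONDS(5*60 * DIV(UNIX_SECONDS(CAST(created_at AS TIMESTAMP)), 5*60)) AS TIMESTAMP)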

    custom_errors = {
        CONNECTION_DATABASE_PERMISSIONS_REGEX: (
            __("We were unable to connect to your database. Please "
               "confirm that your service account has the Viewer "
               "and Job User roles on the project."),
            SupersetErrorType.CONNECTION_DATABASE_PERMISSIONS_ERROR,
        ),
    }

    @classmethod
    def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
        tt = target_type.upper()
        if tt == utils.TemporalType.DATE:
            return f"CAST('{dttm.date().isoformat()}' AS DATE)"
        if tt == utils.TemporalType.DATETIME:
            return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)"""
        if tt == utils.TemporalType.TIME:
            return f"""CAST('{dttm.strftime("%H:%M:%S.%f")}' AS TIME)"""
        if tt == utils.TemporalType.TIMESTAMP:
            return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)"""
        return None
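        # Illustration (hypothetical value): convert_dttm("DATE", datetime(2021, 1, 1))
        # returns "CAST('2021-01-01' AS DATE)".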

    @classmethod
    def fetch_data(cls,
                   cursor: Any,
                   limit: Optional[int] = None) -> List[Tuple[Any, ...]]:
        data = super().fetch_data(cursor, limit)
        # Support for the BigQuery Row type (google.cloud.bigquery.table.Row),
        # introduced in PR #4071
        if data and type(data[0]).__name__ == "Row":
            data = [r.values() for r in data]  # type: ignore
        return data

    @staticmethod
    def _mutate_label(label: str) -> str:
        """
        BigQuery field_name should start with a letter or underscore and contain only
        alphanumeric characters. Labels that start with a number are prefixed with an
        underscore. Any unsupported characters are replaced with underscores and an
        md5 hash is added to the end of the label to avoid possible collisions.

        :param label: Expected expression label
        :return: Conditionally mutated label
        """
        label_hashed = "_" + hashlib.md5(label.encode("utf-8")).hexdigest()

        # if label starts with number, add underscore as first character
        label_mutated = "_" + label if re.match(r"^\d", label) else label

        # replace non-alphanumeric characters with underscores
        label_mutated = re.sub(r"[^\w]+", "_", label_mutated)
        if label_mutated != label:
            # add first 5 chars from md5 hash to label to avoid possible collisions
            label_mutated += label_hashed[:6]

        return label_mutated
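        # Illustration (hash shown symbolically): _mutate_label("count(1)") becomes
        # "count_1_" plus "_" and the first 5 md5 hex chars of the original label,
        # e.g. "count_1__ab12c" (hypothetical hash value).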

    @classmethod
    def _truncate_label(cls, label: str) -> str:
        """BigQuery requires column names start with either a letter or
        underscore. To make sure this is always the case, an underscore is prefixed
        to the md5 hash of the original label.

        :param label: expected expression label
        :return: truncated label
        """
        return "_" + hashlib.md5(label.encode("utf-8")).hexdigest()

    @classmethod
    def normalize_indexes(
            cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Normalizes indexes for more consistency across db engines

        :param indexes: Raw indexes as returned by SQLAlchemy
        :return: cleaner, more aligned index definition
        """
        normalized_idxs = []
        # Fixing a bug/behavior observed in pybigquery==0.4.15 where
        # the index's `column_names` == [None]
        # Here we're returning only non-None indexes
        for ix in indexes:
            column_names = ix.get("column_names") or []
            ix["column_names"] = [
                col for col in column_names if col is not None
            ]
            if ix["column_names"]:
                normalized_idxs.append(ix)
        return normalized_idxs
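        # Illustration (hypothetical SQLAlchemy output):
        #   [{"name": "partition", "column_names": [None]},
        #    {"name": "clustering", "column_names": ["user_id"]}]
        # normalizes to
        #   [{"name": "clustering", "column_names": ["user_id"]}]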

    @classmethod
    def extra_table_metadata(cls, database: "Database", table_name: str,
                             schema_name: str) -> Dict[str, Any]:
        indexes = database.get_indexes(table_name, schema_name)
        if not indexes:
            return {}
        partitions_columns = [
            index.get("column_names", []) for index in indexes
            if index.get("name") == "partition"
        ]
        cluster_columns = [
            index.get("column_names", []) for index in indexes
            if index.get("name") == "clustering"
        ]
        return {
            "partitions": {
                "cols": partitions_columns
            },
            "clustering": {
                "cols": cluster_columns
            },
        }

    @classmethod
    def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
        """
        The BigQuery dialect does not allow backticks in nested field names, so
        literal_column is used to reference them.
        https://docs.sqlalchemy.org/en/latest/core/tutorial.html#using-more-specific-text-with-table-literal-column-and-column
        Column labels are also specified explicitly so that duplicate column
        names do not appear in the result.
        """
        return [
            literal_column(c["name"]).label(c["name"].replace(".", "__"))
            for c in cols
        ]

    @classmethod
    def epoch_to_dttm(cls) -> str:
        return "TIMESTAMP_SECONDS({col})"

    @classmethod
    def epoch_ms_to_dttm(cls) -> str:
        return "TIMESTAMP_MILLIS({col})"

    @classmethod
    def df_to_sql(cls, df: pd.DataFrame, **kwargs: Any) -> None:
        """
        Upload data from a Pandas DataFrame to BigQuery. Calls
        `pandas_gbq.to_gbq()`, which requires `pandas_gbq` to be installed.

        :param df: Dataframe with data to be uploaded
        :param kwargs: kwargs to be passed to to_gbq() method. Requires that `schema`,
        `name` and `con` are present in kwargs. `name` and `schema` are combined
         and passed to `to_gbq()` as `destination_table`.
        """
        try:
            import pandas_gbq
            from google.oauth2 import service_account
        except ImportError:
            raise Exception(
                "Could not import libraries `pandas_gbq` or `google.oauth2`, which are "
                "required to be installed in your environment in order "
                "to upload data to BigQuery")

        if not ("name" in kwargs and "schema" in kwargs and "con" in kwargs):
            raise Exception(
                "name, schema and con need to be defined in kwargs")

        gbq_kwargs = {}
        gbq_kwargs["project_id"] = kwargs["con"].engine.url.host
        gbq_kwargs[
            "destination_table"] = f"{kwargs.pop('schema')}.{kwargs.pop('name')}"

        # add credentials if they are set on the SQLAlchemy Dialect:
        creds = kwargs["con"].dialect.credentials_info
        if creds:
            credentials = service_account.Credentials.from_service_account_info(
                creds)
            gbq_kwargs["credentials"] = credentials

        # Only pass through supported kwargs
        supported_kwarg_keys = {"if_exists"}
        for key in supported_kwarg_keys:
            if key in kwargs:
                gbq_kwargs[key] = kwargs[key]
        pandas_gbq.to_gbq(df, **gbq_kwargs)
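        # Usage sketch (hypothetical engine and names, not from the original code):
        #   BigQueryEngineSpec.df_to_sql(
        #       df,
        #       name="users",
        #       schema="analytics",
        #       con=engine,              # SQLAlchemy engine for bigquery://<project>
        #       if_exists="replace",
        #   )
        # which ends up calling
        #   pandas_gbq.to_gbq(df, destination_table="analytics.users",
        #                     project_id=<engine host>, if_exists="replace", ...)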