Esempio n. 1
0
class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView):
    """Compact inline CRUD view over ``SqlMetric`` rows.

    Permissions are resolved against the ``Dataset`` permission name with the
    read/write method mapping; only the related-view and API routes are
    registered.
    """

    datamodel = SQLAInterface(models.SqlMetric)
    class_permission_name = "Dataset"
    method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP
    include_route_methods = RouteMethod.RELATED_VIEW_SET | RouteMethod.API_SET

    # Page titles.
    list_title = _("Metrics")
    show_title = _("Show Metric")
    add_title = _("Add Metric")
    edit_title = _("Edit Metric")

    page_size = 500
    list_columns = ["metric_name", "verbose_name", "metric_type"]
    # The add form and the edit form expose the very same list of fields.
    edit_columns = [
        "metric_name",
        "description",
        "verbose_name",
        "metric_type",
        "expression",
        "table",
        "d3format",
        "extra",
        "warning_text",
    ]
    add_columns = edit_columns

    # Help texts rendered (as markdown) under the form inputs.
    description_columns = {
        "expression": utils.markdown(
            "a valid, *aggregating* SQL expression as supported by the "
            "underlying backend. Example: `count(DISTINCT userid)`",
            True,
        ),
        "d3format": utils.markdown(
            "d3 formatting string as defined [here]"
            "(https://github.com/d3/d3-format/blob/master/README.md#format). "
            "For instance, this default formatting applies in the Table "
            "visualization and allow for different metric to use different "
            "formats",
            True,
        ),
        "extra": utils.markdown(
            "Extra data to specify metric metadata. Currently supports "
            'metadata of the format: `{ "certification": { "certified_by": '
            '"Data Platform Team", "details": "This metric is the source of truth." '
            '}, "warning_markdown": "This is a warning." }`. This should be modified '
            "from the edit datasource model in Explore to ensure correct formatting.",
            True,
        ),
    }

    label_columns = {
        "metric_name": _("Metric"),
        "description": _("Description"),
        "verbose_name": _("Verbose Name"),
        "metric_type": _("Type"),
        "expression": _("SQL Expression"),
        "table": _("Table"),
        "d3format": _("D3 Format"),
        "extra": _("Extra"),
        "warning_text": _("Warning Message"),
    }

    # The parent table is picked from every SqlaTable; the same field object
    # is reused by the edit form.
    add_form_extra_fields = {
        "table": QuerySelectField(
            "Table",
            query_factory=lambda: db.session.query(models.SqlaTable),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }
    edit_form_extra_fields = add_form_extra_fields
Esempio n. 2
0
class TableModelView(  # pylint: disable=too-many-ancestors
    DatasourceModelView, DeleteMixin, YamlExportMixin
):
    """CRUD view for SQLA tables (datasets).

    On top of the declarative Flask-AppBuilder configuration this view
    validates a table before it is added, fetches its column metadata and
    creates its permissions after add/update, redirects to Explore after an
    edit, and exposes a bulk "Refresh Metadata" action.
    """

    datamodel = SQLAInterface(models.SqlaTable)
    include_route_methods = RouteMethod.CRUD_SET

    list_title = _("Tables")
    show_title = _("Show Table")
    add_title = _("Import a table definition")
    edit_title = _("Edit Table")

    list_columns = ["link", "database_name", "changed_by_", "modified"]
    order_columns = ["modified"]
    add_columns = ["database", "schema", "table_name"]
    edit_columns = [
        "table_name",
        "sql",
        "filter_select_enabled",
        "fetch_values_predicate",
        "database",
        "schema",
        "description",
        "owners",
        "main_dttm_col",
        "default_endpoint",
        "offset",
        "cache_timeout",
        "is_sqllab_view",
        "template_params",
        "extra",
    ]
    base_filters = [["id", DatasourceFilter, lambda: []]]
    show_columns = edit_columns + ["perm", "slices"]
    related_views = [
        TableColumnInlineView,
        SqlMetricInlineView,
    ]
    base_order = ("changed_on", "desc")
    search_columns = (
        "database",
        "schema",
        "table_name",
        "owners",
        "is_sqllab_view",
    )
    description_columns = {
        "slices": _(
            "The list of charts associated with this table. By "
            "altering this datasource, you may change how these associated "
            "charts behave. "
            "Also note that charts need to point to a datasource, so "
            "this form will fail at saving if removing charts from a "
            "datasource. If you want to change the datasource for a chart, "
            "overwrite the chart from the 'explore view'"
        ),
        "offset": _("Timezone offset (in hours) for this datasource"),
        "table_name": _("Name of the table that exists in the source database"),
        "schema": _(
            "Schema, as used only in some databases like Postgres, Redshift "
            "and DB2"
        ),
        "description": Markup(
            'Supports <a href="https://daringfireball.net/projects/markdown/">'
            "markdown</a>"
        ),
        "sql": _(
            "This fields acts a Superset view, meaning that Superset will "
            "run a query against this string as a subquery."
        ),
        "fetch_values_predicate": _(
            "Predicate applied when fetching distinct value to "
            "populate the filter control component. Supports "
            "jinja template syntax. Applies only when "
            "`Enable Filter Select` is on."
        ),
        "default_endpoint": _(
            "Redirects to this endpoint when clicking on the table "
            "from the table list"
        ),
        "filter_select_enabled": _(
            "Whether to populate the filter's dropdown in the explore "
            "view's filter section with a list of distinct values fetched "
            "from the backend on the fly"
        ),
        "is_sqllab_view": _(
            "Whether the table was generated by the 'Visualize' flow "
            "in SQL Lab"
        ),
        "template_params": _(
            "A set of parameters that become available in the query using "
            "Jinja templating syntax"
        ),
        "cache_timeout": _(
            "Duration (in seconds) of the caching timeout for this table. "
            "A timeout of 0 indicates that the cache never expires. "
            "Note this defaults to the database timeout if undefined."
        ),
        "extra": utils.markdown(
            "Extra data to specify table metadata. Currently supports "
            'certification data of the format: `{ "certification": { "certified_by": '
            '"Data Platform Team", "details": "This table is the source of truth." '
            "} }`.",
            True,
        ),
    }
    label_columns = {
        "slices": _("Associated Charts"),
        "link": _("Table"),
        "changed_by_": _("Changed By"),
        "database": _("Database"),
        "database_name": _("Database"),
        "changed_on_": _("Last Changed"),
        "filter_select_enabled": _("Enable Filter Select"),
        "schema": _("Schema"),
        "default_endpoint": _("Default Endpoint"),
        "offset": _("Offset"),
        "cache_timeout": _("Cache Timeout"),
        "table_name": _("Table Name"),
        "fetch_values_predicate": _("Fetch Values Predicate"),
        "owners": _("Owners"),
        "main_dttm_col": _("Main Datetime Column"),
        "description": _("Description"),
        "is_sqllab_view": _("SQL Lab View"),
        "template_params": _("Template parameters"),
        "extra": _("Extra"),
        "modified": _("Modified"),
    }
    edit_form_extra_fields = {
        "database": QuerySelectField(
            "Database",
            query_factory=lambda: db.session.query(models.Database),
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    def pre_add(self, item: "models.SqlaTable") -> None:
        """Validate the table definition before it is persisted."""
        validate_sqlatable(item)

    def post_add(  # pylint: disable=arguments-differ
        self, item: "models.SqlaTable", flash_message: bool = True
    ) -> None:
        """Fetch column metadata and create permissions for a saved table.

        :param item: the table that was just added (or updated)
        :param flash_message: whether to flash the two-phase setup hint
        """
        item.fetch_metadata()
        create_table_permissions(item)
        if flash_message:
            flash(
                _(
                    "The table was created. "
                    "As part of this two-phase configuration "
                    "process, you should now click the edit button by "
                    "the new table to configure it."
                ),
                "info",
            )

    def post_update(self, item: "models.SqlaTable") -> None:
        """Run the post-add work again, without the creation hint."""
        self.post_add(item, flash_message=False)

    def _delete(self, pk: int) -> None:
        # DatasourceModelView precedes DeleteMixin in the bases, so dispatch
        # to DeleteMixin's implementation explicitly.
        DeleteMixin._delete(self, pk)

    @expose("/edit/<pk>", methods=["GET", "POST"])
    @has_access
    def edit(self, pk: int) -> FlaskResponse:
        """Simple hack to redirect to explore view after saving.

        A ``str`` response from the parent view is a rendered page (the
        form itself) and is returned as-is; any other response is replaced
        with a redirect to the Explore view of table ``pk``.
        """
        resp = super().edit(pk)
        if isinstance(resp, str):
            return resp
        return redirect(f"/superset/explore/table/{pk}/")

    @staticmethod
    def _format_column_changes(changes: Dict[str, List[str]]) -> str:
        """Render ``{table: [col, ...]}`` as ``"table (col, ...), ..."``."""
        return ", ".join(
            f"{table} ({', '.join(cols)})" for table, cols in changes.items()
        )

    @action(
        "refresh", __("Refresh Metadata"), __("Refresh column metadata"), "fa-refresh"
    )
    def refresh(  # pylint: disable=too-many-branches
        self, tables: Union["models.SqlaTable", List["models.SqlaTable"]]
    ) -> FlaskResponse:
        """Re-fetch column metadata for one or several tables.

        A failure on one table does not stop the others: it is logged and
        reported in a "danger" flash message. Successes and the added /
        removed / modified columns are reported in "info" flash messages,
        then the user is redirected to the table list.

        :param tables: a single table, or a list of tables
        """
        import logging  # pylint: disable=import-outside-toplevel

        if not isinstance(tables, list):
            tables = [tables]

        @dataclass
        class RefreshResults:
            successes: List["models.SqlaTable"] = field(default_factory=list)
            failures: List["models.SqlaTable"] = field(default_factory=list)
            added: Dict[str, List[str]] = field(default_factory=dict)
            removed: Dict[str, List[str]] = field(default_factory=dict)
            modified: Dict[str, List[str]] = field(default_factory=dict)

        results = RefreshResults()

        for table_ in tables:
            try:
                metadata_results = table_.fetch_metadata()
            except Exception:  # pylint: disable=broad-except
                # Best effort across the batch, but keep the reason around
                # instead of silently discarding it.
                logging.getLogger(__name__).exception(
                    "Unable to refresh metadata for table %s", table_.table_name
                )
                results.failures.append(table_)
            else:
                if metadata_results.added:
                    results.added[table_.table_name] = metadata_results.added
                if metadata_results.removed:
                    results.removed[table_.table_name] = metadata_results.removed
                if metadata_results.modified:
                    results.modified[table_.table_name] = metadata_results.modified
                results.successes.append(table_)

        # The _() message literals stay inline so that translation
        # extraction keeps finding them.
        if results.successes:
            flash(
                _(
                    "Metadata refreshed for the following table(s): %(tables)s",
                    tables=", ".join(t.table_name for t in results.successes),
                ),
                "info",
            )
        if results.added:
            flash(
                _(
                    "The following tables added new columns: %(tables)s",
                    tables=self._format_column_changes(results.added),
                ),
                "info",
            )
        if results.removed:
            flash(
                _(
                    "The following tables removed columns: %(tables)s",
                    tables=self._format_column_changes(results.removed),
                ),
                "info",
            )
        if results.modified:
            flash(
                _(
                    "The following tables update column metadata: %(tables)s",
                    tables=self._format_column_changes(results.modified),
                ),
                "info",
            )
        if results.failures:
            flash(
                _(
                    "Unable to refresh metadata for the following table(s): %(tables)s",
                    tables=", ".join(t.table_name for t in results.failures),
                ),
                "danger",
            )

        return redirect("/tablemodelview/list/")

    @expose("/list/")
    @has_access
    def list(self) -> FlaskResponse:
        """Serve the React list page when ENABLE_REACT_CRUD_VIEWS is on."""
        if not is_feature_enabled("ENABLE_REACT_CRUD_VIEWS"):
            return super().list()

        return super().render_app_template()
Esempio n. 3
0
class DruidClusterModelView(SupersetModelView, DeleteMixin, YamlExportMixin):
    """CRUD view for Druid clusters.

    Adding or updating a cluster registers a ``database_access`` permission
    for the cluster's ``perm``.
    """

    datamodel = SQLAInterface(models.DruidCluster)
    yaml_dict_key = "databases"

    list_title = _("Druid Clusters")
    show_title = _("Show Druid Cluster")
    add_title = _("Add Druid Cluster")
    edit_title = _("Edit Druid Cluster")

    # Add and edit forms share the same field list.
    add_columns = [
        "verbose_name",
        "broker_host",
        "broker_port",
        "broker_user",
        "broker_pass",
        "broker_endpoint",
        "cache_timeout",
        "cluster_name",
    ]
    edit_columns = add_columns
    list_columns = ["cluster_name", "metadata_last_refreshed"]
    search_columns = ("cluster_name",)

    label_columns = {
        "cluster_name": _("Cluster"),
        "broker_host": _("Broker Host"),
        "broker_port": _("Broker Port"),
        "broker_user": _("Broker Username"),
        "broker_pass": _("Broker Password"),
        "broker_endpoint": _("Broker Endpoint"),
        "verbose_name": _("Verbose Name"),
        "cache_timeout": _("Cache Timeout"),
        "metadata_last_refreshed": _("Metadata Last Refreshed"),
    }
    description_columns = {
        "cache_timeout": _(
            "Duration (in seconds) of the caching timeout for this cluster. "
            "A timeout of 0 indicates that the cache never expires. "
            "Note this defaults to the global timeout if undefined."
        ),
        "broker_user": _(
            "Druid supports basic authentication. See "
            "[auth](http://druid.io/docs/latest/design/auth.html) and "
            "druid-basic-security extension"
        ),
        "broker_pass": _(
            "Druid supports basic authentication. See "
            "[auth](http://druid.io/docs/latest/design/auth.html) and "
            "druid-basic-security extension"
        ),
    }

    edit_form_extra_fields = {
        "cluster_name": QuerySelectField(
            "Cluster",
            query_factory=lambda: db.session().query(models.DruidCluster),
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    def pre_add(self, cluster):
        """Register the ``database_access`` permission on ``cluster.perm``."""
        permission_view = cluster.perm
        security_manager.add_permission_view_menu("database_access", permission_view)

    def pre_update(self, cluster):
        """Perform the same permission registration as on add."""
        self.pre_add(cluster)

    def _delete(self, pk):
        # SupersetModelView precedes DeleteMixin in the bases, so call the
        # mixin's implementation explicitly.
        DeleteMixin._delete(self, pk)
Esempio n. 4
0
class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView):  # noqa
    """Compact inline CRUD view over ``DruidMetric`` rows.

    Restricted metrics get a ``metric_access`` permission registered after
    they are added or updated.
    """

    datamodel = SQLAInterface(models.DruidMetric)

    list_title = _("Metrics")
    show_title = _("Show Druid Metric")
    add_title = _("Add Druid Metric")
    edit_title = _("Edit Druid Metric")

    page_size = 500
    list_columns = ["metric_name", "verbose_name", "metric_type"]
    # Add and edit forms share the same field list.
    edit_columns = [
        "metric_name",
        "description",
        "verbose_name",
        "metric_type",
        "json",
        "datasource",
        "d3format",
        "is_restricted",
        "warning_text",
    ]
    add_columns = edit_columns
    validators_columns = {"json": [validate_json]}

    description_columns = {
        "metric_type": utils.markdown(
            "use `postagg` as the metric type if you are defining a "
            "[Druid Post Aggregation]"
            "(http://druid.io/docs/latest/querying/post-aggregations.html)",
            True,
        ),
        "is_restricted": _(
            "Whether access to this metric is restricted "
            "to certain roles. Only roles with the permission "
            "'metric access on XXX (the name of this metric)' "
            "are allowed to access this metric"
        ),
    }
    label_columns = {
        "metric_name": _("Metric"),
        "description": _("Description"),
        "verbose_name": _("Verbose Name"),
        "metric_type": _("Type"),
        "json": _("JSON"),
        "datasource": _("Druid Datasource"),
        "warning_text": _("Warning Message"),
        "is_restricted": _("Is Restricted"),
    }

    # The same field object serves both the add and the edit form.
    add_form_extra_fields = {
        "datasource": QuerySelectField(
            "Datasource",
            query_factory=lambda: db.session().query(models.DruidDatasource),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }
    edit_form_extra_fields = add_form_extra_fields

    def post_add(self, metric):
        """Register ``metric_access`` on ``metric.get_perm()`` if restricted."""
        if not metric.is_restricted:
            return
        security_manager.add_permission_view_menu("metric_access", metric.get_perm())

    def post_update(self, metric):
        """Apply the same permission handling as after an add."""
        self.post_add(metric)
Esempio n. 5
0
class ReportForm(DynamicForm):
    """Form used to define a report and its schedule.

    ``report_id`` and ``schedule_timezone`` are hidden inputs. ``tests`` is
    declared with ``choices=None``; presumably the view fills the choices in
    at runtime — TODO confirm against the calling view.
    """

    report_id = HiddenField()
    schedule_timezone = HiddenField()
    report_title = StringField(
        "Title",
        description="Title will be used as the report's name",
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()],
    )
    description = TextAreaField(
        "Description",
        widget=BS3TextAreaFieldWidget(),
        validators=[DataRequired()],
    )
    owner_name = StringField(
        "Owner Name",
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()],
    )
    owner_email = StringField(
        "Owner Email",
        description="Owner email will be added to the subscribers list",
        widget=BS3TextFieldWidget(),
        validators=[DataRequired(), Email()],
    )
    # The long descriptions below use adjacent string literals: a backslash
    # continuation inside a literal would embed the next line's indentation
    # in the text shown to users.
    subscribers = StringField(
        "Subscribers",
        description=(
            "List of comma separated emails that should receive email "
            "notifications. Automatically adds owner email to this list."
        ),
        widget=BS3TextFieldWidget(),
    )
    tests = SelectMultipleField(
        "Tests",
        description=(
            "List of the tests to include in the report. Only includes "
            "tasks that have run in airflow."
        ),
        choices=None,
        widget=Select2ManyWidget(),
        validators=[DataRequired()],
    )
    schedule_type = SelectField(
        "Schedule",
        description="Select how you want to schedule the report",
        choices=[
            ("manual", "None (Manual triggering)"),
            ("daily", "Daily"),
            ("weekly", "Weekly"),
            ("custom", "Custom (Cron)"),
        ],
        widget=Select2Widget(),
        validators=[DataRequired()],
    )
    schedule_time = TimeField(
        "Time",
        description=(
            "Note that time zone being used is the "
            "selected timezone in your clock interface."
        ),
        render_kw={"class": "form-control"},
        validators=[DataRequired()],
    )
    # Values are strings "0".."6", with "0" meaning Sunday.
    schedule_week_day = SelectField(
        "Day of week",
        description="Select day of a week you want to schedule the report",
        choices=[
            ("0", "Sunday"),
            ("1", "Monday"),
            ("2", "Tuesday"),
            ("3", "Wednesday"),
            ("4", "Thursday"),
            ("5", "Friday"),
            ("6", "Saturday"),
        ],
        widget=Select2Widget(),
        validators=[DataRequired()],
    )
    schedule_custom = StringField(
        "Cron schedule",
        description=(
            'Enter cron schedule (e.g. "0 0 * * *"). '
            "Note that time zone being used is UTC."
        ),
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()],
    )
Esempio n. 6
0
class EtlTableView(SupersetModelView, DeleteMixin):
    """View For EtlTable Model.

    Besides CRUD, it exposes actions that drive the ETL lifecycle of the
    selected tables (sync, re-sync, one-off sync, stop, clear) and a
    debugging action that prints the table's remote ETL SQL.
    """

    datamodel = SQLAInterface(EtlTable)

    # Redirect target after every action. It must be absolute: a relative
    # path would be resolved against the action endpoint's own URL and lead
    # to a non-existent page.
    _list_url = '/etltableview/list/'

    list_columns = [
        'connector.type', 'connector.name', 'table.database', 'table',
        'datasource', 'sql_table_name', 'downloaded_rows', 'progress',
        'sync_last', 'sync_last_time', 'status', 'save_in_prt', 'sync_field',
        'repr_sync_periodic', 'sync_next_time', 'is_valid', 'is_active',
        'is_scheduled',
    ]

    add_columns = [
        'connector', 'table', 'datasource', 'name', 'save_in_prt',
        'calculate_progress', 'sync_field', 'sync_last', 'chunk_size',
        'sync_periodic', 'sync_periodic_hour',
    ]

    # NOTE(review): not a Flask-AppBuilder option; kept in case other code
    # reads it — confirm whether it is still used.
    edit_schema = ['schema']

    edit_columns = [
        'calculate_progress', 'save_in_prt', 'sync_field', 'sync_last',
        'chunk_size', 'sync_periodic', 'sync_periodic_hour', 'is_active',
    ]

    label_columns = {
        'table.database': 'Instance',
        'connector': 'DataSource Connector',
        'table': 'DataSource SQLTable',
        'name': '_etl_{name}',
    }

    description_columns = {
        'table': _('<a href="/tablemodelview/list/">Create table </a>'),
    }

    # Shared by the add and edit forms; both selects store integers.
    etl_extra_fields = {
        'sync_periodic': SelectField(
            choices=EtlPeriod.CHOICES,
            widget=Select2Widget(),
            coerce=int,
        ),
        'sync_periodic_hour': SelectField(
            choices=EtlPeriod.HOURS,
            widget=Select2Widget(),
            coerce=int,
            description=_('Use if you select one of [Once a month, '
                          'Once a week, Once a day] Sync Periodic'),
        ),
    }

    add_form_extra_fields = etl_extra_fields
    edit_form_extra_fields = etl_extra_fields

    @staticmethod
    def _as_list(items):
        """Return the action argument as a list of EtlTable rows.

        Flask-AppBuilder passes a list when an action runs on a selection
        and a single object when it runs from one record's page.
        """
        return items if isinstance(items, list) else [items]

    # actions
    @action('sync_etl', 'Sync', 'Sync data for this table', 'fa-play')
    def sync(self, item):
        """Call ``sync_delay()`` on every selected table."""
        for etl_table in self._as_list(item):
            etl_table.sync_delay()

        return redirect(self._list_url)

    @action('re_sync_etl', 'ReSync',
            'Clear data, and get new data for this table', 'fa-repeat')
    def re_sync(self, item):
        """Clear every selected table, then call ``sync_delay()`` on it."""
        for etl_table in self._as_list(item):
            etl_table.clear()
            etl_table.sync_delay()

        return redirect(self._list_url)

    @action('check_sql', 'Check_sql', 'Check SQL for this table',
            'fa-check')
    def check_sql(self, item):
        """Print ``remote_etl_sql()`` of every selected table.

        The output goes to the server process's stdout, not to the user.
        """
        for etl_table in self._as_list(item):
            print(etl_table.remote_etl_sql())

        return redirect(self._list_url)

    @action('sync_etl_once', 'Sync once', 'Sync data for this table',
            'fa-step-forward')
    def sync_once(self, item):
        """Call ``sync_once()`` on every selected table."""
        for etl_table in self._as_list(item):
            etl_table.sync_once()

        return redirect(self._list_url)

    @action('sync_etl_stop', 'Sync stop', 'Stop sync data for this table',
            'fa-stop')
    def sync_stop(self, item):
        """Call ``stop()`` on every selected table."""
        for etl_table in self._as_list(item):
            etl_table.stop()

        return redirect(self._list_url)

    @action('clear_etl', 'Clear', 'Stop sync data and clear data',
            'fa-trash-o')
    def clear_etl(self, item):
        """Call ``clear()`` on every selected table."""
        for etl_table in self._as_list(item):
            etl_table.clear()

        return redirect(self._list_url)

    def pre_add(self, obj):
        """Check data before save, then create the table and schedule it.

        :raises Exception: when no table name was entered
        """
        # ``not`` also rejects a missing (None) name, not just "".
        if not obj.name:
            raise Exception('Enter a table name')

        obj.create_table()
        obj.sync_next_time = obj.get_next_sync()

    def pre_update(self, obj):
        """Recompute the next sync time from the (possibly edited) settings."""
        obj.sync_next_time = obj.get_next_sync()

    def pre_delete(self, obj):
        """Delete the backing table when ``exist_table()`` reports one."""
        logging.info('pre_delete')

        if obj.exist_table():
            obj.delete_table()
Esempio n. 7
0
class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):
    """CRUD view for Druid datasources.

    Adding a datasource whose name already exists on the same cluster is
    refused; after add/update the datasource's metrics are refreshed and
    its access permissions are registered.
    """

    datamodel = SQLAInterface(models.DruidDatasource)
    include_route_methods = RouteMethod.CRUD_SET
    list_title = _("Druid Datasources")
    show_title = _("Show Druid Datasource")
    add_title = _("Add Druid Datasource")
    edit_title = _("Edit Druid Datasource")

    list_columns = ["datasource_link", "cluster", "changed_by_", "modified"]
    order_columns = ["datasource_link", "modified"]
    related_views = [DruidColumnInlineView, DruidMetricInlineView]
    edit_columns = [
        "datasource_name",
        "cluster",
        "description",
        "owners",
        "is_hidden",
        "filter_select_enabled",
        "fetch_values_from",
        "default_endpoint",
        "offset",
        "cache_timeout",
    ]
    search_columns = ("datasource_name", "cluster", "description", "owners")
    add_columns = edit_columns
    show_columns = add_columns + ["perm", "slices"]
    page_size = 500
    base_order = ("datasource_name", "asc")
    description_columns = {
        "slices": _(
            "The list of charts associated with this table. By "
            "altering this datasource, you may change how these associated "
            "charts behave. "
            "Also note that charts need to point to a datasource, so "
            "this form will fail at saving if removing charts from a "
            "datasource. If you want to change the datasource for a chart, "
            "overwrite the chart from the 'explore view'"
        ),
        "offset": _("Timezone offset (in hours) for this datasource"),
        "description": Markup(
            'Supports <a href="'
            'https://daringfireball.net/projects/markdown/">markdown</a>'
        ),
        "fetch_values_from": _(
            "Time expression to use as a predicate when retrieving "
            "distinct values to populate the filter component. "
            "Only applies when `Enable Filter Select` is on. If "
            "you enter `7 days ago`, the distinct list of values in "
            "the filter will be populated based on the distinct value over "
            "the past week"
        ),
        "filter_select_enabled": _(
            "Whether to populate the filter's dropdown in the explore "
            "view's filter section with a list of distinct values fetched "
            "from the backend on the fly"
        ),
        "default_endpoint": _(
            "Redirects to this endpoint when clicking on the datasource "
            "from the datasource list"
        ),
        "cache_timeout": _(
            "Duration (in seconds) of the caching timeout for this datasource. "
            "A timeout of 0 indicates that the cache never expires. "
            "Note this defaults to the cluster timeout if undefined."
        ),
    }
    base_filters = [["id", DatasourceFilter, lambda: []]]
    label_columns = {
        "slices": _("Associated Charts"),
        "datasource_link": _("Data Source"),
        "cluster": _("Cluster"),
        "description": _("Description"),
        "owners": _("Owners"),
        "is_hidden": _("Is Hidden"),
        "filter_select_enabled": _("Enable Filter Select"),
        "default_endpoint": _("Default Endpoint"),
        "offset": _("Time Offset"),
        "cache_timeout": _("Cache Timeout"),
        "datasource_name": _("Datasource Name"),
        "fetch_values_from": _("Fetch Values From"),
        "changed_by_": _("Changed By"),
        "modified": _("Modified"),
    }
    edit_form_extra_fields = {
        "cluster": QuerySelectField(
            "Cluster",
            query_factory=lambda: db.session.query(models.DruidCluster),
            widget=Select2Widget(extra_classes="readonly"),
        ),
        "datasource_name": StringField(
            "Datasource Name", widget=BS3TextFieldROWidget()
        ),
    }

    def pre_add(self, item: "models.DruidDatasource") -> None:
        """Refuse a datasource whose name already exists on its cluster.

        :raises Exception: if a datasource with the same ``datasource_name``
            and ``cluster_id`` is already stored
        """
        # no_autoflush: the pending ``item`` must not be flushed by the
        # existence query itself.
        with db.session.no_autoflush:
            query = db.session.query(models.DruidDatasource).filter(
                models.DruidDatasource.datasource_name == item.datasource_name,
                models.DruidDatasource.cluster_id == item.cluster_id,
            )
            if db.session.query(query.exists()).scalar():
                raise Exception(get_datasource_exist_error_msg(item.full_name))

    def post_add(self, item: "models.DruidDatasource") -> None:
        """Refresh metrics and register the datasource/schema permissions."""
        item.refresh_metrics()
        security_manager.add_permission_view_menu("datasource_access", item.get_perm())
        if item.schema:
            security_manager.add_permission_view_menu("schema_access", item.schema_perm)

    def post_update(self, item: "models.DruidDatasource") -> None:
        """Run the same post-save work as after an add."""
        self.post_add(item)

    def _delete(self, pk: int) -> None:
        # DatasourceModelView precedes DeleteMixin in the bases, so dispatch
        # to DeleteMixin's implementation explicitly.
        DeleteMixin._delete(self, pk)
Esempio n. 8
0
class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView):  # noqa
    """Compact inline CRUD view over ``DruidColumn`` rows.

    An optional ``dimension_spec_json`` is validated before a column is
    added or updated, and the column's metrics are refreshed afterwards.
    """

    datamodel = SQLAInterface(models.DruidColumn)

    list_title = _('Columns')
    show_title = _('Show Druid Column')
    add_title = _('Add Druid Column')
    edit_title = _('Edit Druid Column')

    list_widget = ListWidgetWithCheckboxes

    edit_columns = [
        'column_name', 'verbose_name', 'description', 'dimension_spec_json',
        'datasource', 'groupby', 'filterable'
    ]
    add_columns = edit_columns
    list_columns = [
        'column_name', 'verbose_name', 'type', 'groupby', 'filterable'
    ]
    can_delete = False
    page_size = 500
    label_columns = {
        'column_name': _('Column'),
        'type': _('Type'),
        'datasource': _('Datasource'),
        'groupby': _('Groupable'),
        'filterable': _('Filterable'),
    }
    description_columns = {
        'filterable':
        _('Whether this column is exposed in the `Filters` section '
          'of the explore view.'),
        'dimension_spec_json':
        utils.markdown(
            'this field can be used to specify  '
            'a `dimensionSpec` as documented [here]'
            '(http://druid.io/docs/latest/querying/dimensionspecs.html). '
            'Make sure to input valid JSON and that the '
            '`outputName` matches the `column_name` defined '
            'above.', True),
    }

    add_form_extra_fields = {
        'datasource':
        QuerySelectField(
            'Datasource',
            query_factory=lambda: db.session().query(models.DruidDatasource),
            allow_blank=True,
            widget=Select2Widget(extra_classes='readonly'),
        ),
    }

    edit_form_extra_fields = add_form_extra_fields

    def pre_add(self, col):
        """Validate ``dimension_spec_json`` on creation too.

        ``dimension_spec_json`` is part of ``add_columns``, so a new column
        must pass the same checks as an edited one.
        """
        self.pre_update(col)

    def pre_update(self, col):
        """Validate ``col.dimension_spec_json`` when one is given.

        :raises ValueError: if the spec is not valid JSON, is not a JSON
            object, lacks ``outputName`` or ``dimension``, or its
            ``outputName`` differs from ``col.column_name``
        """
        if col.dimension_spec_json:
            try:
                dimension_spec = json.loads(col.dimension_spec_json)
            except ValueError as e:
                raise ValueError('Invalid Dimension Spec JSON: ' + str(e)) from e
            if not isinstance(dimension_spec, dict):
                raise ValueError('Dimension Spec must be a JSON object')
            if 'outputName' not in dimension_spec:
                raise ValueError(
                    'Dimension Spec does not contain `outputName`')
            if 'dimension' not in dimension_spec:
                raise ValueError('Dimension Spec is missing `dimension`')
            # `outputName` should be the same as the `column_name`
            if dimension_spec['outputName'] != col.column_name:
                raise ValueError(
                    '`outputName` [{}] unequal to `column_name` [{}]'.format(
                        dimension_spec['outputName'], col.column_name))

    def post_update(self, col):
        """Refresh the metrics derived from this column."""
        col.refresh_metrics()

    def post_add(self, col):
        """Run the same post-save work as after an update."""
        self.post_update(col)
Esempio n. 9
0
class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView):  # noqa
    """Inline CRUD view over ``models.SqlMetric`` rows.

    When a metric flagged ``is_restricted`` is added or edited, a
    ``metric_access`` permission named by ``metric.get_perm()`` is
    registered through ``security_manager.add_permission_view_menu``.
    """
    datamodel = SQLAInterface(models.SqlMetric)

    # Titles of the list / show / add / edit pages.
    list_title = _("Metrics")
    show_title = _("Show Metric")
    add_title = _("Add Metric")
    edit_title = _("Edit Metric")

    list_columns = ["metric_name", "verbose_name", "metric_type"]
    # The add form reuses this exact list (see ``add_columns`` below).
    edit_columns = [
        "metric_name", "description", "verbose_name", "metric_type",
        "expression", "table", "d3format", "is_restricted", "warning_text",
    ]
    description_columns = {
        "expression": utils.markdown(
            "a valid, *aggregating* SQL expression as supported by the "
            "underlying backend. Example: `count(DISTINCT userid)`",
            True,
        ),
        "is_restricted": _(
            "Whether access to this metric is restricted "
            "to certain roles. Only roles with the permission "
            "'metric access on XXX (the name of this metric)' "
            "are allowed to access this metric"
        ),
        "d3format": utils.markdown(
            "d3 formatting string as defined [here]"
            "(https://github.com/d3/d3-format/blob/master/README.md#format). "
            "For instance, this default formatting applies in the Table "
            "visualization and allow for different metric to use different "
            "formats",
            True,
        ),
    }
    add_columns = edit_columns
    page_size = 500
    label_columns = {
        "metric_name": _("Metric"),
        "description": _("Description"),
        "verbose_name": _("Verbose Name"),
        "metric_type": _("Type"),
        "expression": _("SQL Expression"),
        "table": _("Table"),
        "d3format": _("D3 Format"),
        "is_restricted": _("Is Restricted"),
        "warning_text": _("Warning Message"),
    }

    # The owning table is picked from all SqlaTable rows; the widget is
    # rendered with the "readonly" CSS class.
    add_form_extra_fields = {
        "table": QuerySelectField(
            "Table",
            query_factory=lambda: db.session().query(models.SqlaTable),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }
    edit_form_extra_fields = add_form_extra_fields

    def post_add(self, metric):
        """Register ``metric_access`` for a restricted metric after insert."""
        if not metric.is_restricted:
            return
        security_manager.add_permission_view_menu(
            "metric_access", metric.get_perm())

    def post_update(self, metric):
        """Register ``metric_access`` for a restricted metric after edit."""
        if not metric.is_restricted:
            return
        security_manager.add_permission_view_menu(
            "metric_access", metric.get_perm())
Esempio n. 10
0
class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView):  # noqa
    """Inline CRUD view over ``models.DruidMetric`` rows.

    When a metric flagged ``is_restricted`` is added or edited, a
    ``metric_access`` permission named by ``metric.get_perm()`` is
    registered through ``security_manager.add_permission_view_menu``.
    """
    datamodel = SQLAInterface(models.DruidMetric)

    list_title = _('Metrics')
    show_title = _('Show Druid Metric')
    add_title = _('Add Druid Metric')
    edit_title = _('Edit Druid Metric')

    list_columns = ['metric_name', 'verbose_name', 'metric_type']
    edit_columns = [
        'metric_name', 'description', 'verbose_name', 'metric_type', 'json',
        'datasource', 'd3format', 'is_restricted', 'warning_text'
    ]
    add_columns = edit_columns
    page_size = 500
    # The ``json`` field is checked by ``validate_json`` (defined elsewhere).
    validators_columns = {
        'json': [validate_json],
    }
    description_columns = {
        'metric_type':
        utils.markdown(
            'use `postagg` as the metric type if you are defining a '
            '[Druid Post Aggregation]'
            '(http://druid.io/docs/latest/querying/post-aggregations.html)',
            True),
        'is_restricted':
        _('Whether access to this metric is restricted '
          'to certain roles. Only roles with the permission '
          "'metric access on XXX (the name of this metric)' "
          'are allowed to access this metric'),
    }
    label_columns = {
        'metric_name': _('Metric'),
        'description': _('Description'),
        'verbose_name': _('Verbose Name'),
        'metric_type': _('Type'),
        'json': _('JSON'),
        'datasource': _('Druid Datasource'),
        # ``d3format`` is an add/edit column; give it the same label the
        # SQL metric view uses instead of leaving it unlabelled here.
        'd3format': _('D3 Format'),
        'warning_text': _('Warning Message'),
        'is_restricted': _('Is Restricted'),
    }

    # The owning datasource is picked from all DruidDatasource rows; the
    # widget is rendered with the "readonly" CSS class.
    add_form_extra_fields = {
        'datasource':
        QuerySelectField(
            'Datasource',
            query_factory=lambda: db.session().query(models.DruidDatasource),
            allow_blank=True,
            widget=Select2Widget(extra_classes='readonly'),
        ),
    }

    edit_form_extra_fields = add_form_extra_fields

    def post_add(self, metric):
        """Register ``metric_access`` for a restricted metric after insert."""
        if metric.is_restricted:
            security_manager.add_permission_view_menu('metric_access',
                                                      metric.get_perm())

    def post_update(self, metric):
        """Register ``metric_access`` for a restricted metric after edit."""
        if metric.is_restricted:
            security_manager.add_permission_view_menu('metric_access',
                                                      metric.get_perm())
Esempio n. 11
0
class DruidClusterModelView(SupersetModelView, DeleteMixin,
                            YamlExportMixin):  # noqa
    """CRUD view over ``models.DruidCluster`` connection settings.

    Before a cluster is added or updated, a ``database_access`` permission
    for ``cluster.perm`` is registered with ``security_manager``.
    """
    datamodel = SQLAInterface(models.DruidCluster)

    # Titles of the list / show / add / edit pages.
    list_title = _('Druid Clusters')
    show_title = _('Show Druid Cluster')
    add_title = _('Add Druid Cluster')
    edit_title = _('Edit Druid Cluster')

    # Add and edit forms expose the same fields.
    add_columns = [
        'verbose_name', 'broker_host', 'broker_port', 'broker_user',
        'broker_pass', 'broker_endpoint', 'cache_timeout', 'cluster_name',
    ]
    edit_columns = add_columns
    list_columns = ['cluster_name', 'metadata_last_refreshed']
    search_columns = ('cluster_name',)
    label_columns = {
        'cluster_name': _('Cluster'),
        'broker_host': _('Broker Host'),
        'broker_port': _('Broker Port'),
        'broker_user': _('Broker Username'),
        'broker_pass': _('Broker Password'),
        'broker_endpoint': _('Broker Endpoint'),
        'verbose_name': _('Verbose Name'),
        'cache_timeout': _('Cache Timeout'),
        'metadata_last_refreshed': _('Metadata Last Refreshed'),
    }
    description_columns = {
        'cache_timeout': _(
            'Duration (in seconds) of the caching timeout for this cluster. '
            'A timeout of 0 indicates that the cache never expires. '
            'Note this defaults to the global timeout if undefined.'),
        'broker_user': _(
            'Druid supports basic authentication. See '
            '[auth](http://druid.io/docs/latest/design/auth.html) and '
            'druid-basic-security extension'),
        'broker_pass': _(
            'Druid supports basic authentication. See '
            '[auth](http://druid.io/docs/latest/design/auth.html) and '
            'druid-basic-security extension'),
    }

    edit_form_extra_fields = {
        'cluster_name': QuerySelectField(
            'Cluster',
            query_factory=lambda: db.session().query(models.DruidCluster),
            widget=Select2Widget(extra_classes='readonly'),
        ),
    }

    def pre_add(self, cluster):
        """Register the ``database_access`` permission for ``cluster.perm``."""
        security_manager.add_permission_view_menu(
            'database_access', cluster.perm)

    def pre_update(self, cluster):
        """Apply the same permission registration as ``pre_add``."""
        self.pre_add(cluster)

    def _delete(self, pk):
        """Delete through ``DeleteMixin._delete`` explicitly.

        ``SupersetModelView`` precedes ``DeleteMixin`` in the bases, so an
        inherited ``_delete`` could otherwise resolve to a different
        implementation.
        """
        DeleteMixin._delete(self, pk)
Esempio n. 12
0
class Service_Pipeline_ModelView_Base():

    label_title = '任务流'
    datamodel = SQLAInterface(Service_Pipeline)
    check_redirect_list_url = '/service_pipeline_modelview/list/'

    base_permissions = [
        'can_show', 'can_edit', 'can_list', 'can_delete', 'can_add'
    ]
    base_order = ("changed_on", "desc")
    # order_columns = ['id','changed_on']
    order_columns = ['id']

    list_columns = [
        'project', 'service_pipeline_url', 'creator', 'modified',
        'operate_html'
    ]
    add_columns = [
        'project', 'name', 'describe', 'namespace', 'images', 'env',
        'resource_memory', 'resource_cpu', 'resource_gpu', 'replicas',
        'dag_json', 'alert_status', 'alert_user', 'parameter'
    ]
    show_columns = [
        'project', 'name', 'describe', 'namespace', 'run_id', 'created_by',
        'changed_by', 'created_on', 'changed_on', 'expand_html',
        'parameter_html'
    ]
    edit_columns = add_columns

    base_filters = [["id", Service_Pipeline_Filter, lambda: []]]  # 设置权限过滤器
    conv = GeneralModelConverter(datamodel)

    add_form_extra_fields = {
        "name":
        StringField(_(datamodel.obj.lab('name')),
                    description="英文名(字母、数字、- 组成),最长50个字符",
                    widget=BS3TextFieldWidget(),
                    validators=[
                        Regexp("^[a-z][a-z0-9\-]*[a-z0-9]$"),
                        Length(1, 54),
                        DataRequired()
                    ]),
        "project":
        QuerySelectField(_(datamodel.obj.lab('project')),
                         query_factory=filter_join_org_project,
                         allow_blank=True,
                         widget=Select2Widget()),
        "dag_json":
        StringField(
            _(datamodel.obj.lab('dag_json')),
            widget=MyBS3TextAreaFieldWidget(
                rows=10),  # 传给widget函数的是外层的field对象,以及widget函数的参数
        ),
        "namespace":
        StringField(_(datamodel.obj.lab('namespace')),
                    description="部署task所在的命名空间(目前无需填写)",
                    default='service',
                    widget=BS3TextFieldWidget()),
        "images":
        StringField(
            _(datamodel.obj.lab('images')),
            default='ccr.ccs.tencentyun.com/cube-studio/service-pipeline',
            description="推理服务镜像",
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]),
        "node_selector":
        StringField(_(datamodel.obj.lab('node_selector')),
                    description="部署task所在的机器(目前无需填写)",
                    widget=BS3TextFieldWidget(),
                    default=datamodel.obj.node_selector.default.arg),
        "image_pull_policy":
        SelectField(
            _(datamodel.obj.lab('image_pull_policy')),
            description="镜像拉取策略(always为总是拉取远程镜像,IfNotPresent为若本地存在则使用本地镜像)",
            widget=Select2Widget(),
            choices=[['Always', 'Always'], ['IfNotPresent', 'IfNotPresent']]),
        "alert_status":
        MySelectMultipleField(label=_(datamodel.obj.lab('alert_status')),
                              widget=Select2ManyWidget(),
                              choices=[[x, x] for x in [
                                  'Created', 'Pending', 'Running', 'Succeeded',
                                  'Failed', 'Unknown', 'Waiting', 'Terminated'
                              ]],
                              description="选择通知状态"),
        "alert_user":
        StringField(label=_(datamodel.obj.lab('alert_user')),
                    widget=BS3TextFieldWidget(),
                    description="选择通知用户,每个用户使用逗号分隔"),
        "label":
        StringField(_(datamodel.obj.lab('label')),
                    description='中文名',
                    widget=BS3TextFieldWidget(),
                    validators=[DataRequired()]),
        "resource_memory":
        StringField(_(datamodel.obj.lab('resource_memory')),
                    default=Service_Pipeline.resource_memory.default.arg,
                    description='内存的资源使用限制,示例1G,10G, 最大100G,如需更多联系管路员',
                    widget=BS3TextFieldWidget(),
                    validators=[DataRequired()]),
        "resource_cpu":
        StringField(_(datamodel.obj.lab('resource_cpu')),
                    default=Service_Pipeline.resource_cpu.default.arg,
                    description='cpu的资源使用限制(单位核),示例 0.4,10,最大50核,如需更多联系管路员',
                    widget=BS3TextFieldWidget(),
                    validators=[DataRequired()]),
        "resource_gpu":
        StringField(
            _(datamodel.obj.lab('resource_gpu')),
            default=0,
            description=
            'gpu的资源使用限制(单位卡),示例:1,2,训练任务每个容器独占整卡。申请具体的卡型号,可以类似 1(V100),目前支持T4/V100/A100/VGPU',
            widget=BS3TextFieldWidget()),
        "replicas":
        StringField(_(datamodel.obj.lab('replicas')),
                    default=Service_Pipeline.replicas.default.arg,
                    description='pod副本数,用来配置高可用',
                    widget=BS3TextFieldWidget(),
                    validators=[DataRequired()]),
        "env":
        StringField(_(datamodel.obj.lab('env')),
                    default=Service_Pipeline.env.default.arg,
                    description=
                    '使用模板的task自动添加的环境变量,支持模板变量。书写格式:每行一个环境变量env_key=env_value',
                    widget=MyBS3TextAreaFieldWidget()),
    }

    edit_form_extra_fields = add_form_extra_fields

    # 检测是否具有编辑权限,只有creator和admin可以编辑
    def check_edit_permission(self, item):
        """Return True if the current user may edit/delete ``item``.

        Users holding a role whose lower-cased name is ``admin`` are always
        allowed. Otherwise only the user recorded as ``item.created_by`` is
        allowed. On refusal a warning is flashed and False is returned.
        """
        user_roles = [role.name.lower() for role in get_user_roles()]
        if "admin" in user_roles:
            return True
        # ``created_by`` may exist on the model but be unset (None); the old
        # hasattr() check let that through and then crashed on ``.username``.
        creator = getattr(item, 'created_by', None)
        if g.user and g.user.username and creator is not None:
            if g.user.username == creator.username:
                return True
        flash('just creator can edit/delete ', 'warning')
        return False

    # 验证args参数
    @pysnooper.snoop(watch_explode=('item'))
    def service_pipeline_args_check(self, item):
        """Validate ``item``'s name and DAG and store the DAG in dependency order.

        An empty ``dag_json`` is replaced by ``'{}'``. After validation,
        ``item.dag_json`` is rewritten (pretty-printed). Its tasks are ordered
        so that each task comes after every upstream task it lists.

        Raises MyappException (after flashing a warning) when the DAG
        contains a cycle or names an upstream task that is not in the DAG.
        """
        core.validate_str(item.name, 'name')
        if not item.dag_json:
            item.dag_json = '{}'
        core.validate_json(item.dag_json)

        def order_by_upstream(dag_json):
            """Return a new dict with ``dag_json``'s tasks in dependency order.

            A task whose config is empty, or has no / an empty ``upstream``,
            is independent. Any other task is placed once all of its
            upstream tasks have been placed. Passes repeat until every task
            is placed. A pass that places nothing means the remaining tasks
            sit on a cycle or reference an unknown upstream.
            """
            order_dag = {}
            pending = list(dag_json)
            while pending:
                blocked = []
                # Walk ``pending`` without mutating it; removing from the
                # list being iterated used to skip tasks.
                for task_name in pending:
                    task = dag_json[task_name]
                    if (not task or 'upstream' not in task
                            or not task['upstream']):
                        ready = True
                    else:
                        ready = all(upstream_name in order_dag
                                    for upstream_name in task['upstream'])
                    if ready:
                        order_dag[task_name] = task
                    else:
                        blocked.append(task_name)
                if len(blocked) == len(pending):
                    # No progress possible: cycle or unknown upstream.
                    flash('dag service pipeline 存在循环或未知上游',
                          category='warning')
                    raise MyappException('dag service pipeline 存在循环或未知上游')
                pending = blocked
            return order_dag

        item.dag_json = json.dumps(order_by_upstream(json.loads(item.dag_json)),
                                   ensure_ascii=False,
                                   indent=4)

    # @pysnooper.snoop()
    def pre_add(self, item):
        """Normalise a new service pipeline before it is inserted.

        - ``name``: ``_`` -> ``-``, cut to 54 characters, lower-cased and
          stripped of leading/trailing ``-``.
        - ``create_datetime`` / ``change_datetime``: the same current time.
        - ``parameter``: reset to an empty JSON object.
        - ``volume_mount``: the project's mounts followed by this pipeline's
          configmap mounted at ``/config/`` (comma separated).
        """
        item.name = item.name.replace('_', '-')[0:54].lower().strip('-')
        # item.alert_status = ','.join(item.alert_status)
        # self.service_pipeline_args_check(item)
        # One timestamp, so creation and modification times are identical.
        now = datetime.datetime.now()
        item.create_datetime = now
        item.change_datetime = now
        item.parameter = json.dumps({}, indent=4, ensure_ascii=False)
        # Skip an empty/None project mount instead of crashing on None or
        # emitting a leading comma.
        configmap_mount = "%s(configmap):/config/" % item.name
        item.volume_mount = ",".join(
            mount for mount in (item.project.volume_mount, configmap_mount)
            if mount)

    # @pysnooper.snoop()
    def pre_update(self, item):
        """Normalise an edited service pipeline before it is saved.

        - ``expand``, ``parameter``, ``dag_json``: re-dumped as indented JSON,
          or ``'{}'`` when empty (``expand`` is validated first).
        - ``name``: normalised exactly like in ``pre_add``.
        - ``alert_status``: a list/tuple from the multi-select is joined
          with commas.
        - ``change_datetime``: set to now.
        - ``volume_mount``: project mounts plus this pipeline's configmap.
        """
        if item.expand:
            core.validate_json(item.expand)
            item.expand = json.dumps(json.loads(item.expand),
                                     indent=4,
                                     ensure_ascii=False)
        else:
            item.expand = '{}'
        item.name = item.name.replace('_', '-')[0:54].lower().strip('-')
        # Only join a real sequence: joining an already-joined string would
        # put a comma between every character, and None would raise.
        if isinstance(item.alert_status, (list, tuple)):
            item.alert_status = ','.join(item.alert_status)
        # self.service_pipeline_args_check(item)
        item.change_datetime = datetime.datetime.now()
        item.parameter = json.dumps(
            json.loads(item.parameter), indent=4,
            ensure_ascii=False) if item.parameter else '{}'
        item.dag_json = json.dumps(
            json.loads(item.dag_json), indent=4,
            ensure_ascii=False) if item.dag_json else '{}'
        configmap_mount = "%s(configmap):/config/" % item.name
        item.volume_mount = ",".join(
            mount for mount in (item.project.volume_mount, configmap_mount)
            if mount)

    @expose("/my/list/")
    def my(self):
        """Return the current user's service pipelines as a JSON response.

        Responds with ``status=0`` and a list of ``to_json()`` payloads for
        pipelines whose ``created_by_fk`` is ``g.user.id``, or ``status=-1``
        with the error text if anything raises.
        """
        try:
            user_id = g.user.id
            service_pipelines = []
            # Without a user id there is nothing to list. The old code fell
            # through and returned None, which Flask rejects as a response.
            if user_id:
                service_pipelines = db.session.query(
                    Service_Pipeline).filter_by(created_by_fk=user_id).all()
            back = [service_pipeline.to_json()
                    for service_pipeline in service_pipelines]
            return json_response(message='success', status=0, result=back)
        except Exception as e:
            print(e)
            return json_response(message=str(e), status=-1, result={})

    def check_service_pipeline_perms(user_fun):
        """Decorator for views taking a ``service_pipeline_id`` keyword.

        It is applied at class-definition time.
        - A missing, zero or non-numeric id gets a 404.
        - Admins are then let through.
        - Other users get a 404 if the pipeline does not exist.
        - Other users get a 403 unless the pipeline's project is one of
          their joined projects.
        """
        import functools

        def _not_found():
            # Shared 404 response for an absent or unknown pipeline id.
            response = make_response("service_pipeline_id not exist")
            response.status_code = 404
            return response

        # functools.wraps keeps the view's __name__/__doc__ on the wrapper.
        @functools.wraps(user_fun)
        def wrapper(*args, **kwargs):
            try:
                service_pipeline_id = int(
                    kwargs.get('service_pipeline_id', '0'))
            except (TypeError, ValueError):
                # Non-numeric ids used to raise and surface as a 500.
                service_pipeline_id = 0
            if not service_pipeline_id:
                return _not_found()

            user_roles = [role.name.lower() for role in g.user.roles]
            if "admin" in user_roles:
                return user_fun(*args, **kwargs)

            join_projects_id = security_manager.get_join_projects_id(
                db.session)
            service_pipeline = db.session.query(Service_Pipeline).filter_by(
                id=service_pipeline_id).first()
            if service_pipeline is None:
                return _not_found()
            if service_pipeline.project.id in join_projects_id:
                return user_fun(*args, **kwargs)

            response = make_response("no perms to run pipeline %s" %
                                     service_pipeline_id)
            response.status_code = 403
            return response

        return wrapper

    # 构建同步服务
    def build_http(self, service_pipeline):
        """Placeholder for building a synchronous (gateway) service.

        Not implemented: performs no action and returns None.
        """

    # 构建异步服务
    def build_mq_consumer(self, service_pipeline):
        """Deploy ``service_pipeline`` as an asynchronous MQ consumer.

        1. Creates a configmap named after the pipeline that holds its
           ``dag.json``.
        2. Creates a deployment of the same name in
           ``SERVICE_PIPELINE_NAMESPACE`` running
           ``python mq-pipeline/cube_kafka.py``.
        Image pull secrets are the configured ``HUBSECRET`` plus the
        current user's repository secrets.

        Debug tracing (pysnooper) was removed here: it dumped every local,
        including env values and secret names, on each deploy.
        """
        namespace = conf.get('SERVICE_PIPELINE_NAMESPACE')
        name = service_pipeline.name
        # Copy the configured list; appending to the object held by ``conf``
        # leaked one user's secrets into every later deployment.
        image_secrets = list(conf.get('HUBSECRET', []))
        user_hubsecrets = db.session.query(Repository.hubsecret).filter(
            Repository.created_by_fk == g.user.id).all()
        for hubsecret in user_hubsecrets:
            if hubsecret[0] not in image_secrets:
                image_secrets.append(hubsecret[0])

        from myapp.utils.py.py_k8s import K8s
        k8s_client = K8s(service_pipeline.project.cluster.get(
            'KUBECONFIG', ''))
        dag_json = service_pipeline.dag_json if service_pipeline.dag_json else '{}'

        # Configmap carrying the DAG for the service.
        config_data = {"dag.json": dag_json}
        k8s_client.create_configmap(namespace=namespace,
                                    name=name,
                                    data=config_data,
                                    labels={'app': name})

        env = service_pipeline.env
        jaeger_host = conf.get('SERVICE_PIPELINE_JAEGER', '')
        if jaeger_host:
            tracing_env = {'JAEGER_HOST': jaeger_host, 'SERVICE_NAME': name}
            if isinstance(env, dict):
                env = dict(env, **tracing_env)
            else:
                # The form stores env as text, one KEY=VALUE per line, so the
                # old item assignment raised TypeError. NOTE(review): assumes
                # create_deployment accepts this text format - confirm.
                lines = [env] if env else []
                lines.extend('%s=%s' % pair for pair in tracing_env.items())
                env = '\n'.join(lines)

        k8s_client.create_deployment(
            namespace=namespace,
            name=name,
            replicas=service_pipeline.replicas,
            labels={
                "app": name,
                "username": service_pipeline.created_by.username
            },
            command=['bash', '-c', "python mq-pipeline/cube_kafka.py"],
            args=None,
            volume_mount=service_pipeline.volume_mount,
            working_dir=service_pipeline.working_dir,
            node_selector=service_pipeline.get_node_selector(),
            resource_memory=service_pipeline.resource_memory,
            resource_cpu=service_pipeline.resource_cpu,
            resource_gpu=service_pipeline.resource_gpu
            if service_pipeline.resource_gpu else '',
            image_pull_policy=conf.get('IMAGE_PULL_POLICY', 'Always'),
            image_pull_secrets=image_secrets,
            image=service_pipeline.images,
            hostAliases=conf.get('HOSTALIASES', ''),
            env=env,
            privileged=False,
            accounts=None,
            username=service_pipeline.created_by.username,
            ports=None)

    # 只能有一个入口。不能同时接口两个队列
    # # @event_logger.log_this
    @expose("/run_service_pipeline/<service_pipeline_id>",
            methods=["GET", "POST"])
    @check_service_pipeline_perms
    def run_service_pipeline(self, service_pipeline_id):
        """(Re)deploy a service pipeline according to its root node.

        The existing deployment is cleared first. If the first root node
        uses template group ``endpoint``:
        - template ``mq`` builds an MQ consumer;
        - template ``gateway`` builds an HTTP service.
        Always redirects to the pipeline's log page.
        """
        service_pipeline = db.session.query(Service_Pipeline).filter_by(
            id=service_pipeline_id).first()
        # An empty stored DAG is treated as an empty object, not a crash.
        dag_json = json.loads(service_pipeline.dag_json or '{}')
        root_nodes_name = service_pipeline.get_root_node_name()
        self.clear(service_pipeline_id)
        if root_nodes_name:
            # Only one entry point is supported, so only the first root is used.
            root_node = dag_json.get(root_nodes_name[0]) or {}
            # .get(): nodes lacking these keys used to raise KeyError.
            template_group = root_node.get('template-group')
            template = root_node.get('template')
            if template_group == 'endpoint':
                if template == 'mq':
                    self.build_mq_consumer(service_pipeline)  # asynchronous
                elif template == 'gateway':
                    self.build_http(service_pipeline)  # synchronous

        return redirect("/service_pipeline_modelview/web/log/%s" %
                        service_pipeline_id)

    # # @event_logger.log_this
    @expose("/web/<service_pipeline_id>", methods=["GET"])
    def web(self, service_pipeline_id):
        """Render ``link.html`` pointing at the static pipeline editor page.

        The id is passed through as the ``pipeline_id`` query parameter. The
        old unused pipeline lookup, no-op commit and debug print were removed;
        the rendered page is unchanged.
        """
        data = {
            # TODO: change this URL once front-end/back-end integration is done.
            "url":
            '/static/appbuilder/vison/index.html?pipeline_id=%s' %
            service_pipeline_id
        }
        return self.render_template('link.html', data=data)

    # # @event_logger.log_this
    @expose("/web/log/<service_pipeline_id>", methods=["GET"])
    def web_log(self, service_pipeline_id):
        """Embed the run-details page of the pipeline's current run.

        Template choice depends on the project's cluster:
        - ``link.html`` when its ``NAME`` equals the ``ENVIRONMENT`` config;
        - ``external_link.html`` otherwise.
        Without a ``run_id``, a warning is flashed and the user is
        redirected to ``/service_pipeline_modelview/web/<id>``.
        """
        service_pipeline = db.session.query(Service_Pipeline).filter_by(
            id=service_pipeline_id).first()
        if not service_pipeline.run_id:
            flash('no running instance', 'warning')
            return redirect('/service_pipeline_modelview/web/%s' %
                            service_pipeline.id)

        cluster = service_pipeline.project.cluster
        run_url = (cluster.get('PIPELINE_URL') + "runs/details/" +
                   service_pipeline.run_id)
        data = {
            "url": run_url,
            "target": "div.page_f1flacxk:nth-of-type(0)",
            "delay": 500,
            "loading": True,
        }
        if cluster['NAME'] == conf.get('ENVIRONMENT'):
            template = 'link.html'
        else:
            template = 'external_link.html'
        return self.render_template(template, data=data)

    # 链路跟踪
    @expose("/web/monitoring/<service_pipeline_id>", methods=["GET"])
    def web_monitoring(self, service_pipeline_id):
        """Redirect to the Grafana view of a running service pipeline.

        The target URL is ``GRAFANA_HOST`` (trailing ``/`` trimmed) +
        ``GRAFANA_TASK_PATH`` + the pipeline name. Without a ``run_id``, a
        warning is flashed and the user goes back to
        ``/service_pipeline_modelview/web/<id>``.
        """
        service_pipeline = db.session.query(Service_Pipeline).filter_by(
            id=int(service_pipeline_id)).first()
        if not service_pipeline.run_id:
            flash('no running instance', 'warning')
            return redirect('/service_pipeline_modelview/web/%s' %
                            service_pipeline.id)
        grafana_host = service_pipeline.project.cluster.get(
            'GRAFANA_HOST', '').strip('/')
        return redirect(grafana_host + conf.get('GRAFANA_TASK_PATH') +
                        service_pipeline.name)

    # # @event_logger.log_this
    @expose("/web/pod/<service_pipeline_id>", methods=["GET"])
    def web_pod(self, service_pipeline_id):
        """Embed the K8s dashboard search for the pipeline's pods.

        The search uses ``SERVICE_PIPELINE_NAMESPACE`` and the pipeline
        name with ``_`` replaced by ``-``. Template choice:
        - ``link.html`` when the cluster ``NAME`` equals ``ENVIRONMENT``;
        - ``external_link.html`` otherwise.
        """
        service_pipeline = db.session.query(Service_Pipeline).filter_by(
            id=service_pipeline_id).first()
        cluster = service_pipeline.project.cluster
        dashboard = cluster.get('K8S_DASHBOARD_CLUSTER', '')
        search = '#/search?namespace=%s&q=%s' % (
            conf.get('SERVICE_PIPELINE_NAMESPACE'),
            service_pipeline.name.replace('_', '-'))
        data = {
            "url": dashboard + search,
            "target": "div.kd-chrome-container.kd-bg-background",
            "delay": 500,
            "loading": True,
        }
        if cluster['NAME'] == conf.get('ENVIRONMENT'):
            template = 'link.html'
        else:
            template = 'external_link.html'
        return self.render_template(template, data=data)

    @expose('/clear/<service_pipeline_id>', methods=['POST', "GET"])
    def clear(self, service_pipeline_id):
        """Delete the pipeline's deployment, flash a notice, go to the list.

        The URL variable must be named ``service_pipeline_id`` to match this
        parameter, because Flask passes URL variables as keyword arguments.
        The URL shape ``/clear/<id>`` is unchanged.
        ``run_service_pipeline`` also calls this method directly.
        """
        service_pipeline = db.session.query(Service_Pipeline).filter_by(
            id=service_pipeline_id).first()

        from myapp.utils.py.py_k8s import K8s
        k8s_client = K8s(service_pipeline.project.cluster.get(
            'KUBECONFIG', ''))
        namespace = conf.get('SERVICE_PIPELINE_NAMESPACE')
        k8s_client.delete_deployment(namespace=namespace,
                                     name=service_pipeline.name)

        flash('服务清理完成', category='warning')
        return redirect('/service_pipeline_modelview/list/')

    @expose("/config/<service_pipeline_id>", methods=("GET", 'POST'))
    def pipeline_config(self, service_pipeline_id):
        """Return (and on POST first update) a service pipeline's editor config.

        A POST body may carry ``config`` and/or ``dag_json`` objects. Each
        non-empty one is stored as indented JSON before the response is
        built. The response bundles:
        - the pipeline's id, name, label and project;
        - the UI form config and the jump buttons;
        - the stored ``dag_json`` / ``config``.
        An unknown id yields ``status=1``.
        """
        pipeline = db.session.query(Service_Pipeline).filter_by(
            id=service_pipeline_id).first()
        if not pipeline:
            return jsonify({"status": 1, "message": "服务流不存在", "result": {}})
        if request.method.lower() == 'post':
            # get_json() can yield None; treat that as "nothing to update".
            data = request.get_json() or {}
            request_config = data.get('config', {})
            request_dag = data.get('dag_json', {})
            if request_config:
                pipeline.config = json.dumps(request_config,
                                             indent=4,
                                             ensure_ascii=False)
            if request_dag:
                pipeline.dag_json = json.dumps(request_dag,
                                               indent=4,
                                               ensure_ascii=False)
            db.session.commit()

        # Icon shared by both jump buttons.
        link_icon_svg = (
            '<svg t="1644980982636" class="icon" viewBox="0 0 1024 1024" '
            'version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="2611" '
            'width="128" height="128"><path d="M913.937279 113.328092c-32.94432-32.946366-76.898391-51.089585-123.763768-51.089585s-90.819448 18.143219-123.763768 51.089585L416.737356 362.999454c-32.946366 32.94432-51.089585 76.898391-51.089585 123.763768s18.143219 90.819448 51.087539 123.763768c25.406646 25.40767 57.58451 42.144866 93.053326 48.403406 1.76418 0.312108 3.51915 0.463558 5.249561 0.463558 14.288424 0 26.951839-10.244318 29.519314-24.802896 2.879584-16.322757-8.016581-31.889291-24.339338-34.768875-23.278169-4.106528-44.38386-15.081487-61.039191-31.736818-21.61018-21.61018-33.509185-50.489928-33.509185-81.322144s11.899004-59.711963 33.509185-81.322144l15.864316-15.864316c-0.267083 1.121544-0.478907 2.267647-0.6191 3.440355-1.955538 16.45988 9.800203 31.386848 26.260084 33.344432 25.863041 3.072989 49.213865 14.378475 67.527976 32.692586 21.608134 21.608134 33.509185 50.489928 33.509185 81.322144s-11.901051 59.71401-33.509185 81.322144L318.53987 871.368764c-21.61018 21.61018-50.489928 33.511231-81.322144 33.511231-30.832216 0-59.711963-11.901051-81.322144-33.511231-21.61018-21.61018-33.509185-50.489928-33.509185-81.322144s11.899004-59.711963 33.509185-81.322144l169.43597-169.438017c11.720949-11.718903 11.720949-30.722722 0-42.441625-11.718903-11.718903-30.722722-11.718903-42.441625 0L113.452935 666.282852c-32.946366 32.94432-51.089585 76.898391-51.089585 123.763768 0 46.865377 18.143219 90.819448 51.089585 123.763768 32.94432 32.946366 76.898391 51.091632 123.763768 51.091632s90.819448-18.145266 123.763768-51.091632l249.673409-249.671363c32.946366-32.94432 51.089585-76.898391 51.089585-123.763768-0.002047-46.865377-18.145266-90.819448-51.089585-123.763768-27.5341-27.536146-64.073294-45.240367-102.885252-49.854455-3.618411-0.428765-7.161097-0.196475-10.508331 0.601704l211.589023-211.589023c21.61018-21.61018 50.489928-33.509185 '
            '81.322144-33.509185s59.711963 11.899004 81.322144 33.509185c21.61018 21.61018 33.509185 50.489928 33.509185 81.322144s-11.899004 59.711963-33.509185 81.322144l-150.180418 150.182464c-11.720949 11.718903-11.720949 30.722722 0 42.441625 11.718903 11.718903 30.722722 11.718903 42.441625 0l150.180418-150.182464c32.946366-32.94432 51.089585-76.898391 51.089585-123.763768C965.026864 190.226482 946.882622 146.272411 913.937279 113.328092z" p-id="2612" fill="#225ed2"></path></svg>'
        )
        config = {
            "id": pipeline.id,
            "name": pipeline.name,
            "label": pipeline.describe,
            "project": pipeline.project.describe,
            "pipeline_ui_config": {
                "alert": {
                    "alert_user": {
                        "type": "str",
                        "item_type": "str",
                        "label": "报警用户",
                        "require": 1,
                        "choice": [],
                        "range": "",
                        "default": "",
                        "placeholder": "报警用户名,逗号分隔",
                        "describe": "报警用户,逗号分隔",
                        "editable": 1,
                        "condition": "",
                        "sub_args": {}
                    }
                }
            },
            "pipeline_jump_button": [{
                "name": "资源查看",
                "action_url": "",
                "icon_svg": link_icon_svg,
            }, {
                "name": "链路追踪",
                "action_url":
                "http://swallow.music.woa.com/myapp/swallow#/?pathUrl=%2Fswallow%2FdispatchOps%2FTaskListManager",
                "icon_svg": link_icon_svg,
            }],
            "pipeline_run_button": [],
            "task_jump_button": [],
            # Stored columns may be empty; json.loads(None) would raise.
            "dag_json": json.loads(pipeline.dag_json or '{}'),
            "config": json.loads(pipeline.config or '{}'),
            "message": "success",
            "status": 0
        }
        return jsonify(config)

    @expose("/template/list/")
    def template_list(self):
        """Return the catalogue of node templates for the pipeline editor.

        Templates are grouped under the keys of ``templte_list``; the display
        order of the groups is given by ``template_group_order``. Every
        template is renumbered with a sequential ``template_id`` (starting at
        1, in group order) and stamped with the current user's name and the
        current time.

        :return: a JSON response containing ``message``, ``status``, the
            group order and the template catalogue.
        """
        # Computed once so all templates share the same stamp. The username is
        # stored as a plain string (a stray trailing comma previously turned it
        # into a one-element tuple, serialised as a JSON array).
        username = g.user.username
        now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        all_template = {
            "message": "success",
            "templte_common_ui_config": {},
            "template_group_order": ["入口", "逻辑节点", "功能节点"],
            "templte_list": {
                "入口": [{
                    "template_name": "kafka",
                    "template_id": 1,
                    "templte_ui_config": {
                        "shell": {
                            "topic": {
                                "type": "str",
                                "item_type": "str",
                                "label": "topic",
                                "require": 1,
                                "choice": [],
                                "range": "",
                                "default": "predict",
                                "placeholder": "",
                                "describe": "kafka topic",
                                "editable": 1,
                                "condition": "",
                                "sub_args": {}
                            },
                            "consumer_num": {
                                "type": "str",
                                "item_type": "str",
                                "label": "消费者数目",
                                "require": 1,
                                "choice": [],
                                "range": "",
                                "default": "4",
                                "placeholder": "",
                                "describe": "消费者数目",
                                "editable": 1,
                                "condition": "",
                                "sub_args": {}
                            },
                            "bootstrap_servers": {
                                "type": "str",
                                "item_type": "str",
                                "label": "地址",
                                "require": 1,
                                "choice": [],
                                "range": "",
                                "default": "127.0.0.1:9092",
                                "placeholder": "",
                                "describe":
                                "xx.xx.xx.xx:9092,xx.xx.xx.xx:9092",
                                "editable": 1,
                                "condition": "",
                                "sub_args": {}
                            },
                            "group": {
                                "type": "str",
                                "item_type": "str",
                                "label": "分组",
                                "require": 1,
                                "choice": [],
                                "range": "",
                                "default": "predict",
                                "placeholder": "",
                                "describe": "消费者分组",
                                "editable": 1,
                                "condition": "",
                                "sub_args": {}
                            }
                        }
                    },
                    "username": username,
                    "changed_on": now,
                    "created_on": now,
                    "label": "kafka",
                    "describe": "消费kafka数据",
                    "help_url": "",
                    "pass_through": {
                        # Any content placed here is passed back through the task's fields.
                    }
                }],
                "逻辑节点": [
                    {
                        "template_name": "switch",
                        "template_id": 2,
                        "templte_ui_config": {
                            "shell": {
                                "case": {
                                    "type": "text",
                                    "item_type": "str",
                                    "label": "表达式",
                                    "require": 1,
                                    "choice": [],
                                    "range": "",
                                    "default":
                                    "int(input['node2'])<3:node4,node5:'3'\ndefault:node6:'0'",
                                    "placeholder": "",
                                    "describe": "条件:下游节点:输出     其中input为节点输入",
                                    "editable": 1,
                                    "condition": "",
                                    "sub_args": {}
                                }
                            }
                        },
                        "username": username,
                        "changed_on": now,
                        "created_on": now,
                        "label": "switch-case逻辑节点",
                        "describe": "控制数据的流量",
                        "help_url": "",
                        "pass_through": {
                            # Any content placed here is passed back through the task's fields.
                        }
                    },
                ],
                "功能节点": [
                    {
                        "template_name": "http",
                        "template_id": 3,
                        "templte_ui_config": {
                            "shell": {
                                "method": {
                                    "type": "choice",
                                    "item_type": "str",
                                    "label": "请求方式",
                                    "require": 1,
                                    "choice": ["GET", "POST"],
                                    "range": "",
                                    "default": "POST",
                                    "placeholder": "",
                                    "describe": "请求方式",
                                    "editable": 1,
                                    "condition": "",
                                    "sub_args": {}
                                },
                                "url": {
                                    "type": "str",
                                    "item_type": "str",
                                    "label": "请求地址",
                                    "require": 1,
                                    "choice": [],
                                    "range": "",
                                    "default": "http://127.0.0.1:8080/api",
                                    "placeholder": "",
                                    "describe": "请求地址",
                                    "editable": 1,
                                    "condition": "",
                                    "sub_args": {}
                                },
                                "header": {
                                    "type": "text",
                                    "item_type": "str",
                                    "label": "请求头",
                                    "require": 1,
                                    "choice": [],
                                    "range": "",
                                    "default": "{}",
                                    "placeholder": "",
                                    "describe": "请求头",
                                    "editable": 1,
                                    "condition": "",
                                    "sub_args": {}
                                },
                                "timeout": {
                                    "type": "int",
                                    "item_type": "str",
                                    "label": "请求超时",
                                    "require": 1,
                                    "choice": [],
                                    "range": "",
                                    "default": "300",
                                    "placeholder": "",
                                    "describe": "请求超时",
                                    "editable": 1,
                                    "condition": "",
                                    "sub_args": {}
                                },
                                # NOTE(review): key "date" holds the request body
                                # ("请求内容"); possibly meant "data" — kept as-is
                                # because consumers may read this key.
                                "date": {
                                    "type": "text",
                                    "item_type": "str",
                                    "label": "请求内容",
                                    "require": 1,
                                    "choice": [],
                                    "range": "",
                                    "default": "{}",
                                    "placeholder": "",
                                    "describe": "请求内容",
                                    "editable": 1,
                                    "condition": "",
                                    "sub_args": {}
                                }
                            }
                        },
                        "username": username,
                        "changed_on": now,
                        "created_on": now,
                        "label": "http请求",
                        "describe": "http请求",
                        "help_url": "",
                        "pass_through": {
                            # Any content placed here is passed back through the task's fields.
                        }
                    },
                    {
                        "template_name": "自定义方法",
                        "template_id": 4,
                        "templte_ui_config": {
                            "shell": {
                                "sdk_path": {
                                    "type": "str",
                                    "item_type": "str",
                                    "label": "函数文件地址",
                                    "require": 1,
                                    "choice": [],
                                    "range": "",
                                    "default": "",
                                    "placeholder": "",
                                    "describe": "函数文件地址,文件名和python类型要相同",
                                    "editable": 1,
                                    "condition": "",
                                    "sub_args": {}
                                },
                            }
                        },
                        "username": username,
                        "changed_on": now,
                        "created_on": now,
                        # NOTE(review): label/describe duplicate the http template's
                        # text; likely a copy-paste leftover — confirm intended text.
                        "label": "http请求",
                        "describe": "http请求",
                        "help_url": "",
                        "pass_through": {
                            # Any content placed here is passed back through the task's fields.
                        }
                    }
                ],
            },
            "status": 0
        }

        # Renumber templates sequentially across all groups, in group order,
        # so ids stay unique and contiguous if templates are added or removed.
        index = 1
        for templates in all_template['templte_list'].values():
            for template in templates:
                template['template_id'] = index
                index += 1

        return jsonify(all_template)
Esempio n. 13
0
    def set_column(self, notebook=None):
        """Configure the add/edit form fields and column lists for a notebook.

        :param notebook: the notebook being edited, or ``None`` when a new one
            is being added. When given, the name, project and image widgets are
            rendered read-only, and ``volume_mount`` defaults to the mount
            configuration of the notebook's project.
        """
        # NOTE(review): the description says at most 50 characters, but the
        # Length validator allows up to 54 — confirm which limit is intended.
        self.add_form_extra_fields['name'] = StringField(
            _(self.datamodel.obj.lab('name')),
            default="%s-" % g.user.username + uuid.uuid4().hex[:4],
            description='英文名(字母、数字、-组成),最长50个字符',
            widget=MyBS3TextFieldWidget(readonly=bool(notebook)),
            # Must start with a letter and must not end with '-'. A raw string
            # avoids the invalid '\-' escape warning; the pattern is unchanged.
            validators=[
                DataRequired(),
                Regexp(r"^[a-z][a-z0-9\-]*[a-z0-9]$"),
                Length(1, 54)
            ]
        )
        self.add_form_extra_fields['describe'] = StringField(
            _(self.datamodel.obj.lab('describe')),
            default='%s的个人notebook' % g.user.username,
            description='中文描述',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()])

        self.add_form_extra_fields['project'] = QuerySelectField(
            _(self.datamodel.obj.lab('project')),
            default='',
            description=_(r'部署项目组'),
            query_factory=filter_join_org_project,
            widget=MySelect2Widget(
                extra_classes="readonly" if notebook else None, new_web=False),
        )
        self.add_form_extra_fields['images'] = SelectField(
            _(self.datamodel.obj.lab('images')),
            description=_(r'notebook基础环境镜像,如果显示不准确,请删除新建notebook'),
            widget=MySelect2Widget(
                extra_classes="readonly" if notebook else None, new_web=False),
            choices=conf.get('NOTEBOOK_IMAGES', []),
        )

        self.add_form_extra_fields['node_selector'] = StringField(
            _(self.datamodel.obj.lab('node_selector')),
            default='cpu=true,notebook=true',
            description="部署task所在的机器",
            widget=BS3TextFieldWidget())
        self.add_form_extra_fields['image_pull_policy'] = SelectField(
            _(self.datamodel.obj.lab('image_pull_policy')),
            description="镜像拉取策略(Always为总是拉取远程镜像,IfNotPresent为若本地存在则使用本地镜像)",
            widget=Select2Widget(),
            choices=[['Always', 'Always'], ['IfNotPresent', 'IfNotPresent']])
        self.add_form_extra_fields['volume_mount'] = StringField(
            _(self.datamodel.obj.lab('volume_mount')),
            default=notebook.project.volume_mount if notebook else '',
            description=
            '外部挂载,格式:$pvc_name1(pvc):/$container_path1,$pvc_name2(pvc):/$container_path2',
            widget=BS3TextFieldWidget())
        self.add_form_extra_fields['working_dir'] = StringField(
            _(self.datamodel.obj.lab('working_dir')),
            default='/mnt',
            description="工作目录,如果为空,则使用Dockerfile中定义的workingdir",
            widget=BS3TextFieldWidget())
        self.add_form_extra_fields['resource_memory'] = StringField(
            _(self.datamodel.obj.lab('resource_memory')),
            default=Notebook.resource_memory.default.arg,
            description='内存的资源使用限制,示例:1G,20G',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()])
        self.add_form_extra_fields['resource_cpu'] = StringField(
            _(self.datamodel.obj.lab('resource_cpu')),
            default=Notebook.resource_cpu.default.arg,
            description='cpu的资源使用限制(单位:核),示例:2',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()])

        self.add_form_extra_fields['resource_gpu'] = StringField(
            _(self.datamodel.obj.lab('resource_gpu')),
            default='0',
            description=
            'gpu的资源使用限制(单位卡),示例:1,2,训练任务每个容器独占整卡。申请具体的卡型号,可以类似 1(V100),目前支持T4/V100/A100/VGPU',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()])

        columns = [
            'name', 'describe', 'images', 'resource_memory', 'resource_cpu',
            'resource_gpu'
        ]

        # When adding, no mount configuration is offered; the project's mount
        # configuration is used instead.
        self.add_columns = ['project'] + columns

        # When editing, admins may additionally set special mount
        # configurations to cover special cases.
        if g.user.is_admin():
            columns.append('volume_mount')
        self.edit_columns = ['project'] + columns
        # Add and edit forms share the same extra-field definitions (same dict).
        self.edit_form_extra_fields = self.add_form_extra_fields
Esempio n. 14
0
class Hyperparameter_Tuning_ModelView_Base():
    # Data-model binding and page metadata for hyperparameter-tuning records.
    datamodel = SQLAInterface(Hyperparameter_Tuning)
    conv = GeneralModelConverter(datamodel)
    label_title='超参搜索'
    check_redirect_list_url = '/hyperparameter_tuning_modelview/list/'
    # Per-table help link looked up from the HELP_URL config mapping; '' if absent.
    help_url = conf.get('HELP_URL', {}).get(datamodel.obj.__tablename__, '') if datamodel else ''


    base_permissions = ['can_add', 'can_edit', 'can_delete', 'can_list', 'can_show']  # these are the defaults
    # Newest records first.
    base_order = ('id', 'desc')
    base_filters = [["id", HP_Filter, lambda: []]]  # permission filter on visible rows
    order_columns = ['id']
    list_columns = ['project','name_url','describe','job_type','creator','run_url','modified']
    show_columns = ['created_by','changed_by','created_on','changed_on','job_type','name','namespace','describe',
                    'parallel_trial_count','max_trial_count','max_failed_trial_count','objective_type',
                    'objective_goal','objective_metric_name','objective_additional_metric_names','algorithm_name',
                    'algorithm_setting','parameters_html','trial_spec_html','experiment_html']


    # The project selector is restricted through Project_Filter (argument 'org').
    add_form_query_rel_fields = {
        "project": [["name", Project_Filter, 'org']]
    }
    edit_form_query_rel_fields = add_form_query_rel_fields
    # Class-level dict of custom form fields; filled below and extended by set_column().
    edit_form_extra_fields={}

    # Multi-select of job statuses to be notified about ("选择通知状态").
    edit_form_extra_fields["alert_status"] = MySelectMultipleField(
        label=_(datamodel.obj.lab('alert_status')),
        widget=Select2ManyWidget(),
        choices=[[x, x] for x in
                 ['Pending', 'Running', 'Succeeded', 'Failed', 'Unknown', 'Waiting', 'Terminated']],
        description="选择通知状态",
    )

    # Experiment name: lowercase letters, digits and '-', starting with a letter
    # and not ending with '-'.
    # NOTE(review): the description says at most 50 characters, but Length
    # allows up to 54 — confirm which limit is intended.
    edit_form_extra_fields['name'] = StringField(
        _(datamodel.obj.lab('name')),
        description='英文名(字母、数字、- 组成),最长50个字符',
        widget=BS3TextFieldWidget(),
        # Raw string: '\-' in a plain literal is an invalid escape sequence
        # (warning in modern Python); the regex pattern itself is unchanged.
        validators=[DataRequired(), Regexp(r"^[a-z][a-z0-9\-]*[a-z0-9]$"), Length(1, 54)]
    )
    # Required free-text description.
    edit_form_extra_fields['describe'] = StringField(
        _(datamodel.obj.lab('describe')),
        description='中文描述',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )
    # Namespace the experiment runs in; the default is taken from the model
    # column's `default.arg` (presumably the SQLAlchemy column default).
    edit_form_extra_fields['namespace'] = StringField(
        _(datamodel.obj.lab('namespace')),
        description='运行命名空间',
        widget=BS3TextFieldWidget(),
        default=datamodel.obj.namespace.default.arg,
        validators=[DataRequired()]
    )

    # Trial budget fields; defaults come from the model columns' `default.arg`.
    edit_form_extra_fields['parallel_trial_count'] = IntegerField(
        _(datamodel.obj.lab('parallel_trial_count')),
        default=datamodel.obj.parallel_trial_count.default.arg,
        description='可并行的计算实例数目',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )
    # NOTE(review): the description text reads "max number of parallel compute
    # instances", which matches parallel_trial_count rather than max_trial_count
    # — confirm the intended wording.
    edit_form_extra_fields['max_trial_count'] = IntegerField(
        _(datamodel.obj.lab('max_trial_count')),
        default=datamodel.obj.max_trial_count.default.arg,
        description='最大并行的计算实例数目',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )
    edit_form_extra_fields['max_failed_trial_count'] = IntegerField(
        _(datamodel.obj.lab('max_failed_trial_count')),
        default=datamodel.obj.max_failed_trial_count.default.arg,
        description='最大失败的计算实例数目',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )
    # Objective: direction (maximize/minimize), goal threshold and metric names,
    # which must match what the user's training code reports.
    edit_form_extra_fields['objective_type'] = SelectField(
        _(datamodel.obj.lab('objective_type')),
        default=datamodel.obj.objective_type.default.arg,
        description='目标函数类型(和自己代码中对应)',
        widget=Select2Widget(),
        choices=[['maximize', 'maximize'], ['minimize', 'minimize']],
        validators=[DataRequired()]
    )

    edit_form_extra_fields['objective_goal'] = FloatField(
        _(datamodel.obj.lab('objective_goal')),
        default=datamodel.obj.objective_goal.default.arg,
        description='目标门限',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )
    edit_form_extra_fields['objective_metric_name'] = StringField(
        _(datamodel.obj.lab('objective_metric_name')),
        default=datamodel.obj.objective_metric_name.default.arg,
        description='目标函数(和自己代码中对应)',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )
    # Optional: no DataRequired validator.
    edit_form_extra_fields['objective_additional_metric_names'] = StringField(
        _(datamodel.obj.lab('objective_additional_metric_names')),
        default=datamodel.obj.objective_additional_metric_names.default.arg,
        description='其他目标函数(和自己代码中对应)',
        widget=BS3TextFieldWidget()
    )
    # Supported search algorithms, as [value, label] pairs for the select widget.
    algorithm_name_choices = ['grid', 'random', 'hyperband', 'bayesianoptimization']
    algorithm_name_choices = [[algorithm_name_choice, algorithm_name_choice] for algorithm_name_choice in
                              algorithm_name_choices]
    edit_form_extra_fields['algorithm_name'] = SelectField(
        _(datamodel.obj.lab('algorithm_name')),
        default=datamodel.obj.algorithm_name.default.arg,
        description='搜索算法',
        widget=Select2Widget(),
        choices=algorithm_name_choices,
        validators=[DataRequired()]
    )
    # Optional algorithm settings string (make_experiment splits it as
    # comma-separated key=value pairs).
    edit_form_extra_fields['algorithm_setting'] = StringField(
        _(datamodel.obj.lab('algorithm_setting')),
        default=datamodel.obj.algorithm_setting.default.arg,
        widget=BS3TextFieldWidget(),
        description='搜索算法配置'
    )

    # Read-only example of the search-parameter JSON; process_form() removes this
    # field from the submitted form so it is never processed.
    edit_form_extra_fields['parameters_demo'] = StringField(
        _(datamodel.obj.lab('parameters_demo')),
        description='搜索参数示例,标准json格式,注意:所有整型、浮点型都写成字符串型',
        widget=MyCodeArea(code=core.hp_parameters_demo()),
    )
    # Search-space definition as JSON text (parsed with json.loads in make_experiment).
    edit_form_extra_fields['parameters'] = StringField(
        _(datamodel.obj.lab('parameters')),
        default=datamodel.obj.parameters.default.arg,
        description='搜索参数,注意:所有整型、浮点型都写成字符串型',
        widget=MyBS3TextAreaFieldWidget(rows=10),
        validators=[DataRequired()]
    )
    # Task/container runtime arguments shared by every trial.
    edit_form_extra_fields['node_selector'] = StringField(
        _(datamodel.obj.lab('node_selector')),
        description="部署task所在的机器(目前无需填写)",
        widget=BS3TextFieldWidget()
    )
    edit_form_extra_fields['working_dir'] = StringField(
        _(datamodel.obj.lab('working_dir')),
        description="工作目录,如果为空,则使用Dockerfile中定义的workingdir",
        widget=BS3TextFieldWidget()
    )
    edit_form_extra_fields['image_pull_policy'] = SelectField(
        _(datamodel.obj.lab('image_pull_policy')),
        description="镜像拉取策略(always为总是拉取远程镜像,IfNotPresent为若本地存在则使用本地镜像)",
        widget=Select2Widget(),
        choices=[['Always', 'Always'], ['IfNotPresent', 'IfNotPresent']]
    )
    edit_form_extra_fields['volume_mount'] = StringField(
        _(datamodel.obj.lab('volume_mount')),
        description='外部挂载,格式:$pvc_name1(pvc):/$container_path1,$pvc_name2(pvc):/$container_path2',
        widget=BS3TextFieldWidget()
    )
    # Per-trial resource limits; defaults come from the model columns.
    edit_form_extra_fields['resource_memory'] = StringField(
        _(datamodel.obj.lab('resource_memory')),
        default=datamodel.obj.resource_memory.default.arg,
        description='内存的资源使用限制(每个测试实例),示例:1G,20G',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )
    edit_form_extra_fields['resource_cpu'] = StringField(
        _(datamodel.obj.lab('resource_cpu')),
        default=datamodel.obj.resource_cpu.default.arg,
        description='cpu的资源使用限制(每个测试实例)(单位:核),示例:2', widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )


    # @pysnooper.snoop()
    def set_column(self, hp=None):
        """Build the add/edit form layout for a hyperparameter-tuning record.

        The job type comes from ``hp.job_type`` when editing, otherwise from the
        ``job_type`` query-string argument. The form contains a common
        fieldset, a job-type specific fieldset (worker count/image/command),
        a task-arguments fieldset and an alert fieldset. The resulting extra
        fields, fieldsets and columns are shared by the add and edit forms.

        :param hp: the record being edited, or ``None`` when adding a new one.
        """
        request_data = request.args.to_dict()
        job_type = request_data.get('job_type', '')
        if hp:
            job_type = hp.job_type

        job_type_choices = [[choice, choice] for choice in ['', 'TFJob', 'XGBoostJob', 'PyTorchJob', 'Job']]

        if hp:
            # Existing record: the job type widget is rendered read-only.
            self.edit_form_extra_fields['job_type'] = SelectField(
                _(self.datamodel.obj.lab('job_type')),
                description="超参搜索的任务类型",
                choices=job_type_choices,
                widget=MySelect2Widget(extra_classes="readonly", value=job_type),
                validators=[DataRequired()]
            )
        else:
            self.edit_form_extra_fields['job_type'] = SelectField(
                _(self.datamodel.obj.lab('job_type')),
                description="超参搜索的任务类型",
                widget=MySelect2Widget(new_web=True, value=job_type),
                choices=job_type_choices,
                validators=[DataRequired()]
            )

        # Job-type settings are stored as JSON in hp.job_json. Parse it once
        # (instead of once per field); an empty dict makes every .get() below
        # fall back to its default, as when there is no record or no JSON.
        job_json = json.loads(hp.job_json) if hp and hp.job_json else {}

        self.edit_form_extra_fields['tf_worker_num'] = IntegerField(
            _(self.datamodel.obj.lab('tf_worker_num')),
            default=job_json.get('tf_worker_num', 3),
            description='工作节点数目',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['tf_worker_image'] = StringField(
            _(self.datamodel.obj.lab('tf_worker_image')),
            default=job_json.get('tf_worker_image', conf.get('KATIB_TFJOB_DEFAULT_IMAGE', '')),
            description='工作节点镜像',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['tf_worker_command'] = StringField(
            _(self.datamodel.obj.lab('tf_worker_command')),
            default=job_json.get('tf_worker_command', 'python xx.py'),
            description='工作节点启动命令',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['job_worker_image'] = StringField(
            _(self.datamodel.obj.lab('job_worker_image')),
            default=job_json.get('job_worker_image', conf.get('KATIB_JOB_DEFAULT_IMAGE', '')),
            description='工作节点镜像',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['job_worker_command'] = StringField(
            _(self.datamodel.obj.lab('job_worker_command')),
            default=job_json.get('job_worker_command', 'python xx.py'),
            description='工作节点启动命令',
            widget=MyBS3TextAreaFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['pytorch_worker_num'] = IntegerField(
            _(self.datamodel.obj.lab('pytorch_worker_num')),
            default=job_json.get('pytorch_worker_num', 3),
            description='工作节点数目',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['pytorch_worker_image'] = StringField(
            _(self.datamodel.obj.lab('pytorch_worker_image')),
            default=job_json.get('pytorch_worker_image', conf.get('KATIB_PYTORCHJOB_DEFAULT_IMAGE', '')),
            description='工作节点镜像',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['pytorch_master_command'] = StringField(
            _(self.datamodel.obj.lab('pytorch_master_command')),
            default=job_json.get('pytorch_master_command', 'python xx.py'),
            description='master节点启动命令',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['pytorch_worker_command'] = StringField(
            _(self.datamodel.obj.lab('pytorch_worker_command')),
            default=job_json.get('pytorch_worker_command', 'python xx.py'),
            description='工作节点启动命令',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )

        self.edit_columns = ['job_type', 'project', 'name', 'namespace', 'describe', 'parallel_trial_count',
                             'max_trial_count', 'max_failed_trial_count',
                             'objective_type', 'objective_goal', 'objective_metric_name',
                             'objective_additional_metric_names',
                             'algorithm_name', 'algorithm_setting', 'parameters_demo',
                             'parameters']
        # The common fieldset gets its own copy so later appends to
        # edit_columns do not leak into it.
        self.edit_fieldsets = [(
            lazy_gettext('common'),
            {"fields": list(self.edit_columns), "expanded": True},
        )]

        # Extra fieldset shown for each job type.
        job_type_columns = {
            'TFJob': ['tf_worker_num', 'tf_worker_image', 'tf_worker_command'],
            'Job': ['job_worker_image', 'job_worker_command'],
            'PyTorchJob': ['pytorch_worker_num', 'pytorch_worker_image',
                           'pytorch_master_command', 'pytorch_worker_command'],
            # NOTE(review): these names are not defined as extra form fields in
            # this view (the PyTorch ones are pytorch_worker_*); unless the model
            # has such columns the XGBoostJob form cannot render — confirm.
            'XGBoostJob': ['pytorchjob_worker_image', 'pytorchjob_worker_command'],
        }
        group_columns = job_type_columns.get(job_type)
        if group_columns:
            self.edit_fieldsets.append((
                lazy_gettext(job_type),
                {"fields": group_columns, "expanded": True},
            ))
            self.edit_columns.extend(group_columns)

        task_column = ['working_dir', 'volume_mount', 'node_selector', 'image_pull_policy',
                       'resource_memory', 'resource_cpu']
        self.edit_fieldsets.append((
            lazy_gettext('task args'),
            {"fields": task_column, "expanded": True},
        ))
        self.edit_columns.extend(task_column)

        self.edit_fieldsets.append((
            lazy_gettext('run experiment'),
            {"fields": ['alert_status'], "expanded": True},
        ))
        self.edit_columns.append('alert_status')

        # The add form reuses exactly the same definitions (shared objects).
        self.add_form_extra_fields = self.edit_form_extra_fields
        self.add_fieldsets = self.edit_fieldsets
        self.add_columns = self.edit_columns


    # 处理form请求
    def process_form(self, form, is_created):
        """Strip the display-only ``parameters_demo`` field from a submitted form.

        ``parameters_demo`` only shows an example of the search parameters and
        must not be processed. The field is removed if present; otherwise the
        form is left untouched. ``is_created`` is not used here.
        """
        form._fields.pop('parameters_demo', None)

    # 生成实验
    # @pysnooper.snoop()
    def make_experiment(self,item):
        """Build the Katib v1alpha3 ``Experiment`` for ``item`` and store it as JSON.

        The search algorithm, objective, search space (``item.parameters``)
        and trial template (``item.trial_spec``) are assembled into a
        ``V1alpha3Experiment``; its pretty-printed dict is written to
        ``item.experiment``. The experiment name is ``item.name`` plus a
        random 4-hex-character suffix.

        :param item: the hyperparameter-tuning row being saved.
        :raises MyappException: if an ``algorithm_setting`` entry is not of
            the form ``key=value``.
        """
        # Search algorithm: ``algorithm_setting`` is a comma-separated list of
        # ``key=value`` pairs; blank entries are ignored.
        algorithm_settings = []
        for setting in (item.algorithm_setting or '').strip().split(','):
            setting = setting.strip()
            if not setting:
                continue
            if '=' not in setting:
                raise MyappException('algorithm_setting must be comma separated key=value pairs, invalid item: ' + setting)
            # Split on the first '=' only so a value may itself contain '='.
            key, value = (part.strip() for part in setting.split('=', 1))
            algorithm_settings.append(V1alpha3AlgorithmSetting(name=key, value=value))

        algorithm = V1alpha3AlgorithmSpec(
            algorithm_name=item.algorithm_name,
            algorithm_settings=algorithm_settings or None
        )

        # Metrics collection. Many collector kinds exist and this should not
        # really be hard-coded; only TFJob gets an explicit collector here
        # (TensorFlow event files under /train). Other job types leave it None.
        metrics_collector_spec = None
        if item.job_type == 'TFJob':
            metrics_collector_spec = V1alpha3MetricsCollectorSpec(
                collector=V1alpha3CollectorSpec(kind="TensorFlowEvent"),
                source=V1alpha3SourceSpec(V1alpha3FileSystemPath(kind="Directory", path="/train")))

        # Objective function.
        objective = V1alpha3ObjectiveSpec(
            goal=item.objective_goal,
            objective_metric_name=item.objective_metric_name,
            type=item.objective_type)

        # Search space: one ParameterSpec per entry of the parameters JSON.
        parameters = []
        for name, spec in json.loads(item.parameters).items():
            if spec['type'] in ('int', 'double'):
                step = spec.get('step', '')
                feasible_space = V1alpha3FeasibleSpace(
                    min=str(spec['min']),
                    max=str(spec['max']),
                    step=str(step) if step else None)
            elif spec['type'] == 'categorical':
                feasible_space = V1alpha3FeasibleSpace(list=spec['list'])
            else:
                # Other types are skipped (validate_parameters rejects them).
                continue
            parameters.append(V1alpha3ParameterSpec(
                feasible_space=feasible_space,
                name=name,
                parameter_type=spec['type']
            ))

        # Trial template: the job spec rendered by merge_trial_spec.
        trial_template = V1alpha3TrialTemplate(
            go_template=V1alpha3GoTemplate(raw_template=item.trial_spec))
        labels = {
            "run-rtx": g.user.username,
            "hp-name": item.name,
        }
        crd_info = conf.get('CRD_INFO')['experiment']
        experiment = V1alpha3Experiment(
            api_version=crd_info['group'] + "/" + crd_info['version'],  # e.g. "kubeflow.org/v1alpha3"
            kind="Experiment",
            metadata=V1ObjectMeta(name=item.name + "-" + uuid.uuid4().hex[:4], namespace=conf.get('KATIB_NAMESPACE'), labels=labels),
            spec=V1alpha3ExperimentSpec(
                algorithm=algorithm,
                max_failed_trial_count=item.max_failed_trial_count,
                max_trial_count=item.max_trial_count,
                metrics_collector_spec=metrics_collector_spec,
                objective=objective,
                parallel_trial_count=item.parallel_trial_count,
                parameters=parameters,
                trial_template=trial_template
            )
        )
        item.experiment = json.dumps(experiment.to_dict(), indent=4, ensure_ascii=False)

    @expose('/create_experiment/<id>',methods=['GET'])
    # @pysnooper.snoop(watch_explode=('hp',))
    def create_experiment(self,id):
        """Submit the stored experiment of hyperparameter-tuning row ``id`` to Kubernetes.

        The row's ``experiment`` JSON is created as a custom resource (group,
        version and plural from ``CRD_INFO['experiment']``) in the
        ``KATIB_NAMESPACE`` namespace of the project's cluster. When no row
        matches, nothing is submitted. Either way the user is redirected back.
        """
        query = db.session.query(Hyperparameter_Tuning)
        hp = query.filter(Hyperparameter_Tuning.id == int(id)).first()
        if hp:
            from myapp.utils.py.py_k8s import K8s
            client = K8s(hp.project.cluster.get('KUBECONFIG', ''))
            target_namespace = conf.get('KATIB_NAMESPACE')
            crd = conf.get('CRD_INFO')['experiment']
            print(hp.experiment)
            client.create_crd(
                group=crd['group'],
                version=crd['version'],
                plural=crd['plural'],
                namespace=target_namespace,
                body=hp.experiment,
            )
            flash('部署完成', 'success')

        self.update_redirect()
        return redirect(self.get_redirect())



    # @pysnooper.snoop(watch_explode=())
    def merge_trial_spec(self,item):
        """Render ``item.trial_spec`` and ``item.job_json`` for the item's job type.

        Image pull secrets are the configured ``HUBSECRET`` names plus the
        current user's repository secrets (deduplicated). They are passed,
        together with the item's volume/resource/working-dir settings, to the
        ``core.merge_*_experiment_template`` helper matching ``item.job_type``.
        The job-type specific form values are stored in ``item.job_json`` as
        indented JSON.

        Only ``TFJob``, ``Job`` and ``PyTorchJob`` are rendered. For any other
        job type ``item.trial_spec`` is left unchanged and ``item.job_json``
        becomes ``"{}"``.
        """
        # Copy the configured list. Appending to the list object held by conf
        # would permanently add this user's secrets to the shared config, so
        # every later request (by any user) would inherit them.
        image_secrets = list(conf.get('HUBSECRET') or [])
        user_hubsecrets = db.session.query(Repository.hubsecret).filter(Repository.created_by_fk == g.user.id).all()
        for hubsecret in user_hubsecrets:
            if hubsecret[0] not in image_secrets:
                image_secrets.append(hubsecret[0])

        image_secrets = [{"name": hubsecret} for hubsecret in image_secrets]

        def common_template_args():
            # Arguments shared by every core.merge_*_experiment_template call.
            # NOTE(review): image_pull_policy is taken from conf, not from the
            # item's own 'image_pull_policy' field -- confirm this is intended.
            return dict(
                node_selector=item.get_node_selector(),
                volume_mount=item.volume_mount,
                image_secrets=image_secrets,
                hostAliases=conf.get('HOSTALIASES', ''),
                workingDir=item.working_dir,
                image_pull_policy=conf.get('IMAGE_PULL_POLICY', 'Always'),
                resource_memory=item.resource_memory,
                resource_cpu=item.resource_cpu,
            )

        job_json = {}
        if item.job_type == 'TFJob':
            item.trial_spec = core.merge_tfjob_experiment_template(
                worker_num=item.tf_worker_num,
                image=item.tf_worker_image,
                command=item.tf_worker_command,
                **common_template_args()
            )
            job_json = {
                "tf_worker_num": item.tf_worker_num,
                "tf_worker_image": item.tf_worker_image,
                "tf_worker_command": item.tf_worker_command,
            }
        elif item.job_type == 'Job':
            item.trial_spec = core.merge_job_experiment_template(
                image=item.job_worker_image,
                command=item.job_worker_command,
                **common_template_args()
            )
            job_json = {
                "job_worker_image": item.job_worker_image,
                "job_worker_command": item.job_worker_command,
            }
        elif item.job_type == 'PyTorchJob':
            item.trial_spec = core.merge_pytorchjob_experiment_template(
                worker_num=item.pytorch_worker_num,
                image=item.pytorch_worker_image,
                master_command=item.pytorch_master_command,
                worker_command=item.pytorch_worker_command,
                **common_template_args()
            )
            job_json = {
                "pytorch_worker_num": item.pytorch_worker_num,
                "pytorch_worker_image": item.pytorch_worker_image,
                "pytorch_master_command": item.pytorch_master_command,
                "pytorch_worker_command": item.pytorch_worker_command,
            }
        # NOTE(review): 'XGBoostJob' is offered in the form but not rendered here.
        item.job_json = json.dumps(job_json, indent=4, ensure_ascii=False)


    # 检验参数是否有效
    # @pysnooper.snoop()
    def validate_parameters(self,parameters,algorithm):
        """Validate and normalize a JSON hyperparameter search space.

        ``parameters`` must decode to a JSON object that maps each parameter
        name to a spec dict. The spec's ``type`` must be one of:

        * ``int``: requires ``min`` and ``max``, both coerced with ``int``,
          with ``min < max``.
        * ``double``: requires ``min`` and ``max``, both coerced with
          ``float``, with ``min < max``. When ``algorithm == 'grid'``,
          ``step`` is also required and is coerced with ``float``.
        * ``categorical``: requires a JSON list under ``list``.

        :param parameters: the search space as a JSON string.
        :param algorithm: the search algorithm name; only ``'grid'`` changes
            the checks.
        :returns: the normalized search space re-serialized as indented JSON.
        :raises MyappException: if decoding or any check fails.
        """
        try:
            parameters = json.loads(parameters)
            if not isinstance(parameters, dict):
                raise ValueError('parameters must be a JSON object')
            for parameter_name, parameter in parameters.items():
                param_type = parameter.get('type') if isinstance(parameter, dict) else None
                if param_type in ('int', 'double') and 'min' in parameter and 'max' in parameter:
                    cast = int if param_type == 'int' else float
                    parameter['min'] = cast(parameter['min'])
                    parameter['max'] = cast(parameter['max'])
                    if not parameter['max'] > parameter['min']:
                        raise Exception('min must lower than max')
                    if param_type == 'double' and algorithm == 'grid':
                        # Grid search needs an explicit step for continuous ranges.
                        if 'step' not in parameter:
                            raise ValueError('step is required for double parameter %s with grid algorithm' % parameter_name)
                        parameter['step'] = float(parameter['step'])
                    continue
                if param_type == 'categorical' and isinstance(parameter.get('list'), list):
                    continue

                # Backslashes are escaped so the message text is unchanged while
                # avoiding invalid escape sequences ('\m', '\s', '\l').
                raise MyappException('parameters type must in [int,double,categorical], and min\\max\\step\\list should exist, and min must lower than max ')

            return json.dumps(parameters,indent=4,ensure_ascii=False)

        except Exception as e:
            print(e)
            raise MyappException('parameters not valid:'+str(e)) from e


    # @pysnooper.snoop()
    def pre_add(self, item):
        """Validate ``item`` and derive its trial spec and Katib experiment.

        The steps run in this order:

        1. Reject a missing job type.
        2. Syntax-check and normalize ``item.parameters``.
        3. Check the resource requests.
        4. Render ``item.trial_spec`` and ``item.job_json``
           (``merge_trial_spec``).
        5. Build ``item.experiment`` (``make_experiment``).

        :raises MyappException: when ``item.job_type`` is None.
        """
        if item.job_type is None:
            raise MyappException("Job type is mandatory")

        # JSON syntax check first, then semantic validation + normalization.
        core.validate_json(item.parameters)
        item.parameters = self.validate_parameters(item.parameters, item.algorithm_name)

        # src_item_json is set elsewhere (not visible here); presumably the
        # item's previous values, passed to the resource checks when present.
        previous = self.src_item_json or {}
        item.resource_memory = core.check_resource_memory(item.resource_memory, previous.get('resource_memory', None))
        item.resource_cpu = core.check_resource_cpu(item.resource_cpu, previous.get('resource_cpu', None))

        self.merge_trial_spec(item)
        self.make_experiment(item)


    def pre_update(self, item):
        """Run the same validation and spec generation as ``pre_add`` on edit."""
        add_hook = self.pre_add
        add_hook(item)

    # Hooks Flask-AppBuilder calls before rendering the add / edit form. Both
    # reuse set_column (defined elsewhere in this class, not visible here),
    # which presumably builds the job-type specific columns and fieldsets.
    pre_add_get=set_column
    pre_update_get=set_column


    @action(
        "copy", __("Copy Hyperparameter Experiment"), confirmation=__('Copy Hyperparameter Experiment'), icon="fa-copy",multiple=True, single=False
    )
    def copy(self, hps):
        """Clone each selected experiment, suffixing ``name`` and ``describe`` with ``-copy``.

        :param hps: a single item or a list of items selected in the list view.
        :returns: a redirect to the referring page, or to the view's default
            redirect target when the request carries no Referer header.
        """
        if not isinstance(hps, list):
            hps = [hps]
        for hp in hps:
            new_hp = hp.clone()
            new_hp.name = new_hp.name + "-copy"
            # describe may be empty/NULL; avoid a TypeError on concatenation.
            new_hp.describe = (new_hp.describe or "") + "-copy"
            # One timestamp so created_on and changed_on agree exactly.
            now = datetime.datetime.now()
            new_hp.created_on = now
            new_hp.changed_on = now
            db.session.add(new_hp)
            db.session.commit()

        # request.referrer is None when the browser sends no Referer header.
        return redirect(request.referrer or self.get_redirect())
Esempio n. 15
0
class TableModelView(  # pylint: disable=too-many-ancestors
        DatasourceModelView, DeleteMixin, YamlExportMixin):
    """CRUD view for ``models.SqlaTable`` datasets.

    Adding a table asks only for database, schema and table name
    (``add_columns``). Its columns and metrics are then edited through the
    inline ``related_views``. The list page is served by
    ``render_app_template``, and a successful edit redirects to the explore
    view of the table.
    """
    datamodel = SQLAInterface(models.SqlaTable)
    # Same permission class name as the dataset's inline column/metric views.
    class_permission_name = "Dataset"
    method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP
    include_route_methods = RouteMethod.CRUD_SET

    list_title = _("Tables")
    show_title = _("Show Table")
    add_title = _("Import a table definition")
    edit_title = _("Edit Table")

    list_columns = ["link", "database_name", "changed_by_", "modified"]
    order_columns = ["modified"]
    add_columns = ["database", "schema", "table_name"]
    edit_columns = [
        "table_name",
        "sql",
        "filter_select_enabled",
        "fetch_values_predicate",
        "database",
        "schema",
        "description",
        "owners",
        "main_dttm_col",
        "default_endpoint",
        "offset",
        "cache_timeout",
        "is_sqllab_view",
        "template_params",
        "extra",
    ]
    # DatasourceFilter (defined elsewhere) is applied to every query of this
    # view; presumably it restricts rows to datasources the user may access.
    base_filters = [["id", DatasourceFilter, lambda: []]]
    # The show page adds the permission name and associated charts.
    show_columns = edit_columns + ["perm", "slices"]
    related_views = [
        TableColumnInlineView,
        SqlMetricInlineView,
    ]
    base_order = ("changed_on", "desc")
    search_columns = ("database", "schema", "table_name", "owners",
                      "is_sqllab_view")
    description_columns = {
        "slices":
        _("The list of charts associated with this table. By "
          "altering this datasource, you may change how these associated "
          "charts behave. "
          "Also note that charts need to point to a datasource, so "
          "this form will fail at saving if removing charts from a "
          "datasource. If you want to change the datasource for a chart, "
          "overwrite the chart from the 'explore view'"),
        "offset":
        _("Timezone offset (in hours) for this datasource"),
        "table_name":
        _("Name of the table that exists in the source database"),
        "schema":
        _("Schema, as used only in some databases like Postgres, Redshift "
          "and DB2"),
        "description":
        Markup(
            'Supports <a href="https://daringfireball.net/projects/markdown/">'
            "markdown</a>"),
        "sql":
        _("This fields acts a Superset view, meaning that Superset will "
          "run a query against this string as a subquery."),
        "fetch_values_predicate":
        _("Predicate applied when fetching distinct value to "
          "populate the filter control component. Supports "
          "jinja template syntax. Applies only when "
          "`Enable Filter Select` is on."),
        "default_endpoint":
        _("Redirects to this endpoint when clicking on the table "
          "from the table list"),
        "filter_select_enabled":
        _("Whether to populate the filter's dropdown in the explore "
          "view's filter section with a list of distinct values fetched "
          "from the backend on the fly"),
        "is_sqllab_view":
        _("Whether the table was generated by the 'Visualize' flow "
          "in SQL Lab"),
        "template_params":
        _("A set of parameters that become available in the query using "
          "Jinja templating syntax"),
        "cache_timeout":
        _("Duration (in seconds) of the caching timeout for this table. "
          "A timeout of 0 indicates that the cache never expires. "
          "Note this defaults to the database timeout if undefined."),
        "extra":
        utils.markdown(
            "Extra data to specify table metadata. Currently supports "
            'metadata of the format: `{ "certification": { "certified_by": '
            '"Data Platform Team", "details": "This table is the source of truth." '
            '}, "warning_markdown": "This is a warning." }`.',
            True,
        ),
    }
    label_columns = {
        "slices": _("Associated Charts"),
        "link": _("Table"),
        "changed_by_": _("Changed By"),
        "database": _("Database"),
        "database_name": _("Database"),
        "changed_on_": _("Last Changed"),
        "filter_select_enabled": _("Enable Filter Select"),
        "schema": _("Schema"),
        "default_endpoint": _("Default Endpoint"),
        "offset": _("Offset"),
        "cache_timeout": _("Cache Timeout"),
        "table_name": _("Table Name"),
        "fetch_values_predicate": _("Fetch Values Predicate"),
        "owners": _("Owners"),
        "main_dttm_col": _("Main Datetime Column"),
        "description": _("Description"),
        "is_sqllab_view": _("SQL Lab View"),
        "template_params": _("Template parameters"),
        "extra": _("Extra"),
        "modified": _("Modified"),
    }
    # On the edit form the database is picked with a Select2 widget styled
    # with the "readonly" class -- presumably so it is not changed after
    # the table has been created.
    edit_form_extra_fields = {
        "database":
        QuerySelectField(
            "Database",
            query_factory=lambda: db.session.query(models.Database),
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    def post_add(  # pylint: disable=arguments-differ
        self,
        item: "models.SqlaTable",
        flash_message: bool = True,
        fetch_metadata: bool = True,
    ) -> None:
        """Finish creating a table: optionally fetch metadata, then create permissions.

        :param item: the saved ``SqlaTable`` row.
        :param flash_message: flash a hint telling the user to edit the new
            table next.
        :param fetch_metadata: call ``item.fetch_metadata()`` before creating
            the table permissions.
        """
        if fetch_metadata:
            item.fetch_metadata()
        create_table_permissions(item)
        if flash_message:
            flash(
                _("The table was created. "
                  "As part of this two-phase configuration "
                  "process, you should now click the edit button by "
                  "the new table to configure it."),
                "info",
            )

    def post_update(self, item: "models.SqlaTable") -> None:
        """Re-run ``create_table_permissions`` after an edit, without fetching metadata or flashing."""
        self.post_add(item, flash_message=False, fetch_metadata=False)

    def _delete(self, pk: int) -> None:
        """Delete row ``pk`` with ``DeleteMixin``'s implementation.

        Calling it explicitly bypasses whichever ``_delete`` the MRO would
        resolve first (``DatasourceModelView`` precedes ``DeleteMixin``).
        """
        DeleteMixin._delete(self, pk)

    @expose("/edit/<pk>", methods=["GET", "POST"])
    @has_access
    def edit(self, pk: str) -> FlaskResponse:
        """Simple hack to redirect to explore view after saving"""
        resp = super().edit(pk)
        # A str response is presumably the rendered form page (GET or failed
        # validation), so it is passed through unchanged.
        if isinstance(resp, str):
            return resp
        return redirect("/superset/explore/table/{}/".format(pk))

    @expose("/list/")
    @has_access
    def list(self) -> FlaskResponse:
        """Serve the table list page via ``render_app_template``."""
        return super().render_app_template()
Esempio n. 16
0
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView):  # noqa
    """Inline CRUD view for the columns (``models.TableColumn``) of a table.

    Columns cannot be deleted from this view (``can_delete = False``). The
    ``table`` field uses a Select2 picker over all ``SqlaTable`` rows.
    """
    datamodel = SQLAInterface(models.TableColumn)

    list_title = _("Columns")
    show_title = _("Show Column")
    add_title = _("Add Column")
    edit_title = _("Edit Column")

    can_delete = False
    list_widget = ListWidgetWithCheckboxes
    edit_columns = [
        "column_name",
        "verbose_name",
        "description",
        "type",
        "groupby",
        "filterable",
        "table",
        "expression",
        "is_dttm",
        "python_date_format",
        "database_expression",
    ]
    add_columns = edit_columns
    list_columns = [
        "column_name",
        "verbose_name",
        "type",
        "groupby",
        "filterable",
        "is_dttm",
    ]
    page_size = 500
    description_columns = {
        "is_dttm":
        _("Whether to make this column available as a "
          "[Time Granularity] option, column has to be DATETIME or "
          "DATETIME-like"),
        "filterable":
        _("Whether this column is exposed in the `Filters` section "
          "of the explore view."),
        "type":
        _("The data type that was inferred by the database. "
          "It may be necessary to input a type manually for "
          "expression-defined columns in some cases. In most case "
          "users should not need to alter this."),
        "expression":
        utils.markdown(
            "a valid, *non-aggregating* SQL expression as supported by the "
            "underlying backend. Example: `substr(name, 1, 1)`",
            True,
        ),
        "python_date_format":
        utils.markdown(
            Markup(
                "The pattern of timestamp format, use "
                # Python 2 documentation is end-of-life; link the current docs.
                '<a href="https://docs.python.org/3/library/'
                'datetime.html#strftime-strptime-behavior">'
                "python datetime string pattern</a> "
                "expression. If time is stored in epoch "
                "format, put `epoch_s` or `epoch_ms`. Leave `Database Expression` "
                "below empty if timestamp is stored in "
                "String or Integer(epoch) type"),
            True,
        ),
        "database_expression":
        utils.markdown(
            "The database expression to cast internal datetime "
            "constants to database date/timestamp type according to the DBAPI. "
            "The expression should follow the pattern of "
            "%Y-%m-%d %H:%M:%S, based on different DBAPI. "
            "The string should be a python string formatter \n"
            # Sentence break: the Oracle example and the fallback note were
            # previously run together ("for Oracle Superset uses ...").
            "`Ex: TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')` for Oracle. "
            "Superset uses default expression based on DB URI if this "
            "field is blank.",
            True,
        ),
    }
    label_columns = {
        "column_name": _("Column"),
        "verbose_name": _("Verbose Name"),
        "description": _("Description"),
        "groupby": _("Groupable"),
        "filterable": _("Filterable"),
        "table": _("Table"),
        "expression": _("Expression"),
        "is_dttm": _("Is temporal"),
        "python_date_format": _("Datetime Format"),
        "database_expression": _("Database Expression"),
        "type": _("Type"),
    }

    # The owning table is chosen from all SqlaTable rows; the "readonly"
    # class presumably keeps it from being changed in the inline form.
    add_form_extra_fields = {
        "table":
        QuerySelectField(
            "Table",
            query_factory=lambda: db.session().query(models.SqlaTable),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    edit_form_extra_fields = add_form_extra_fields
Esempio n. 17
0
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView):
    """Inline CRUD view for the columns (``models.TableColumn``) of a table.

    Columns cannot be deleted from this view (``can_delete = False``).
    ``python_date_format`` is restricted by a regex to ``epoch_s``,
    ``epoch_ms`` or an ISO-8601-ordered strftime pattern.
    """
    datamodel = SQLAInterface(models.TableColumn)
    # TODO TODO, review need for this on related_views
    class_permission_name = "Dataset"
    method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP
    include_route_methods = RouteMethod.RELATED_VIEW_SET | RouteMethod.API_SET

    list_title = _("Columns")
    show_title = _("Show Column")
    add_title = _("Add Column")
    edit_title = _("Edit Column")

    can_delete = False
    list_widget = ListWidgetWithCheckboxes
    edit_columns = [
        "column_name",
        "verbose_name",
        "description",
        "type",
        "groupby",
        "filterable",
        "table",
        "expression",
        "is_dttm",
        "python_date_format",
        "extra",
    ]
    add_columns = edit_columns
    list_columns = [
        "column_name",
        "verbose_name",
        "type",
        "groupby",
        "filterable",
        "is_dttm",
    ]
    page_size = 500
    description_columns = {
        "is_dttm":
        _("Whether to make this column available as a "
          "[Time Granularity] option, column has to be DATETIME or "
          "DATETIME-like"),
        "filterable":
        _("Whether this column is exposed in the `Filters` section "
          "of the explore view."),
        "type":
        _("The data type that was inferred by the database. "
          "It may be necessary to input a type manually for "
          "expression-defined columns in some cases. In most case "
          "users should not need to alter this."),
        "expression":
        utils.markdown(
            "a valid, *non-aggregating* SQL expression as supported by the "
            "underlying backend. Example: `substr(name, 1, 1)`",
            True,
        ),
        "python_date_format":
        utils.markdown(
            Markup(
                "The pattern of timestamp format. For strings use "
                # Python 2 documentation is end-of-life; link the current docs.
                '<a href="https://docs.python.org/3/library/'
                'datetime.html#strftime-strptime-behavior">'
                "python datetime string pattern</a> expression which needs to "
                'adhere to the <a href="https://en.wikipedia.org/wiki/ISO_8601">'
                "ISO 8601</a> standard to ensure that the lexicographical ordering "
                "coincides with the chronological ordering. If the timestamp "
                "format does not adhere to the ISO 8601 standard you will need to "
                "define an expression and type for transforming the string into a "
                "date or timestamp. Note currently time zones are not supported. "
                # Trailing space: this sentence was previously glued to the next.
                "If time is stored in epoch format, put `epoch_s` or `epoch_ms`. "
                "If no pattern is specified we fall back to using the optional "
                "defaults on a per database/column name level via the extra parameter."),
            True,
        ),
        "extra":
        utils.markdown(
            "Extra data to specify column metadata. Currently supports "
            # The example JSON was missing the "{" opening the certification object.
            'certification data of the format: `{ "certification": { "certified_by": '
            '"Taylor Swift", "details": "This column is the source of truth." '
            "} }`. This should be modified from the edit datasource model in "
            "Explore to ensure correct formatting.",
            True,
        ),
    }
    label_columns = {
        "column_name": _("Column"),
        "verbose_name": _("Verbose Name"),
        "description": _("Description"),
        "groupby": _("Groupable"),
        "filterable": _("Filterable"),
        "table": _("Table"),
        "expression": _("Expression"),
        "is_dttm": _("Is temporal"),
        "python_date_format": _("Datetime Format"),
        "type": _("Type"),
        # "extra" is an edit column, so give it a label like the other fields.
        "extra": _("Extra"),
    }
    validators_columns = {
        "python_date_format": [
            # Restrict viable values to epoch_s, epoch_ms, or a strftime format
            # which adheres to the ISO 8601 format (without time zone).
            Regexp(
                re.compile(
                    r"""
                    ^(
                        epoch_s|epoch_ms|
                        (?P<date>%Y(-%m(-%d)?)?)([\sT](?P<time>%H(:%M(:%S(\.%f)?)?)?))?
                    )$
                    """,
                    re.VERBOSE,
                ),
                message=_("Invalid date/timestamp format"),
            )
        ]
    }

    # The owning table is chosen from all SqlaTable rows; the "readonly"
    # class presumably keeps it from being changed in the inline form.
    add_form_extra_fields = {
        "table":
        QuerySelectField(
            "Table",
            query_factory=lambda: db.session.query(models.SqlaTable),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    edit_form_extra_fields = add_form_extra_fields
Esempio n. 18
0
class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView):
    """Inline CRUD view for the metrics (``models.SqlMetric``) of a table.

    Only the related-view and API route sets are exposed
    (``include_route_methods``). The ``table`` field uses a Select2 picker
    over all ``SqlaTable`` rows.
    """
    datamodel = SQLAInterface(models.SqlMetric)
    include_route_methods = RouteMethod.RELATED_VIEW_SET | RouteMethod.API_SET

    list_title = _("Metrics")
    show_title = _("Show Metric")
    add_title = _("Add Metric")
    edit_title = _("Edit Metric")

    list_columns = ["metric_name", "verbose_name", "metric_type"]
    # Order here is the field order of the add/edit forms.
    edit_columns = [
        "metric_name",
        "description",
        "verbose_name",
        "metric_type",
        "expression",
        "table",
        "d3format",
        "warning_text",
    ]
    # NOTE(review): "warning_text" is editable and labelled but has no entry
    # in description_columns.
    description_columns = {
        "expression":
        utils.markdown(
            "a valid, *aggregating* SQL expression as supported by the "
            "underlying backend. Example: `count(DISTINCT userid)`",
            True,
        ),
        "d3format":
        utils.markdown(
            "d3 formatting string as defined [here]"
            "(https://github.com/d3/d3-format/blob/master/README.md#format). "
            "For instance, this default formatting applies in the Table "
            "visualization and allow for different metric to use different "
            "formats",
            True,
        ),
    }
    add_columns = edit_columns
    # Presumably large so that all of a table's metrics fit on one page.
    page_size = 500
    label_columns = {
        "metric_name": _("Metric"),
        "description": _("Description"),
        "verbose_name": _("Verbose Name"),
        "metric_type": _("Type"),
        "expression": _("SQL Expression"),
        "table": _("Table"),
        "d3format": _("D3 Format"),
        "warning_text": _("Warning Message"),
    }

    # The owning table is chosen from all SqlaTable rows; the "readonly"
    # class presumably keeps it from being changed in the inline form.
    add_form_extra_fields = {
        "table":
        QuerySelectField(
            "Table",
            query_factory=lambda: db.session().query(models.SqlaTable),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    edit_form_extra_fields = add_form_extra_fields
Esempio n. 19
0
class ConnectionForm(DynamicForm):
    """Form for editing and adding Connection.

    Besides the generic connection fields, it declares
    ``extra__<conn_type>__<name>`` fields for specific connection types. Per
    the comment below, their rendered values are stored as JSON in the
    ``extra`` field.
    """

    conn_id = StringField(lazy_gettext('Conn Id'), widget=BS3TextFieldWidget())
    conn_type = SelectField(
        lazy_gettext('Conn Type'),
        choices=sorted(_connection_types, key=itemgetter(1)),  # pylint: disable=protected-access
        widget=Select2Widget())
    host = StringField(lazy_gettext('Host'), widget=BS3TextFieldWidget())
    schema = StringField(lazy_gettext('Schema'), widget=BS3TextFieldWidget())
    login = StringField(lazy_gettext('Login'), widget=BS3TextFieldWidget())
    password = PasswordField(lazy_gettext('Password'),
                             widget=BS3PasswordFieldWidget())
    port = IntegerField(lazy_gettext('Port'),
                        validators=[Optional()],
                        widget=BS3TextFieldWidget())
    extra = TextAreaField(lazy_gettext('Extra'),
                          widget=BS3TextAreaFieldWidget())

    # Used to customized the form, the forms elements get rendered
    # and results are stored in the extra field as json. All of these
    # need to be prefixed with extra__ and then the conn_type ___ as in
    # extra__{conn_type}__name. You can also hide form elements and rename
    # others from the connection_form.js file
    extra__jdbc__drv_path = StringField(lazy_gettext('Driver Path'),
                                        widget=BS3TextFieldWidget())
    extra__jdbc__drv_clsname = StringField(lazy_gettext('Driver Class'),
                                           widget=BS3TextFieldWidget())
    extra__google_cloud_platform__project = StringField(
        lazy_gettext('Project Id'), widget=BS3TextFieldWidget())
    extra__google_cloud_platform__key_path = StringField(
        lazy_gettext('Keyfile Path'), widget=BS3TextFieldWidget())
    extra__google_cloud_platform__keyfile_dict = PasswordField(
        lazy_gettext('Keyfile JSON'), widget=BS3PasswordFieldWidget())
    extra__google_cloud_platform__scope = StringField(
        lazy_gettext('Scopes (comma separated)'), widget=BS3TextFieldWidget())
    extra__google_cloud_platform__num_retries = IntegerField(
        lazy_gettext('Number of Retries'),
        validators=[NumberRange(min=0)],
        widget=BS3TextFieldWidget(),
        default=5)
    extra__grpc__auth_type = StringField(lazy_gettext('Grpc Auth Type'),
                                         widget=BS3TextFieldWidget())
    extra__grpc__credential_pem_file = StringField(
        lazy_gettext('Credential Keyfile Path'), widget=BS3TextFieldWidget())
    extra__grpc__scopes = StringField(lazy_gettext('Scopes (comma separated)'),
                                      widget=BS3TextFieldWidget())
    # The JSON examples below previously read {"id", "..."} (invalid JSON).
    extra__yandexcloud__service_account_json = PasswordField(
        lazy_gettext('Service account auth JSON'),
        widget=BS3PasswordFieldWidget(),
        description='Service account auth JSON. Looks like '
        '{"id": "...", "service_account_id": "...", "private_key": "..."}. '
        'Will be used instead of OAuth token and SA JSON file path field if specified.',
    )
    extra__yandexcloud__service_account_json_path = StringField(
        lazy_gettext('Service account auth JSON file path'),
        widget=BS3TextFieldWidget(),
        description=
        'Service account auth JSON file path. File content looks like '
        '{"id": "...", "service_account_id": "...", "private_key": "..."}. '
        'Will be used instead of OAuth token if specified.',
    )
    extra__yandexcloud__oauth = PasswordField(
        lazy_gettext('OAuth Token'),
        widget=BS3PasswordFieldWidget(),
        description=
        'User account OAuth token. Either this or service account JSON must be specified.',
    )
    extra__yandexcloud__folder_id = StringField(
        lazy_gettext('Default folder ID'),
        widget=BS3TextFieldWidget(),
        description=
        'Optional. This folder will be used to create all new clusters and nodes by default',
    )
    extra__yandexcloud__public_ssh_key = StringField(
        lazy_gettext('Public SSH key'),
        widget=BS3TextFieldWidget(),
        # Trailing space so the two literals don't join as "nodesto".
        description=
        'Optional. This key will be placed to all created Compute nodes '
        'to let you have a root shell there',
    )
    extra__kubernetes__in_cluster = BooleanField(
        lazy_gettext('In cluster configuration'))
    extra__kubernetes__kube_config_path = StringField(
        lazy_gettext('Kube config path'), widget=BS3TextFieldWidget())
    extra__kubernetes__kube_config = StringField(
        lazy_gettext('Kube config (JSON format)'), widget=BS3TextFieldWidget())
    extra__kubernetes__namespace = StringField(lazy_gettext('Namespace'),
                                               widget=BS3TextFieldWidget())
Esempio n. 20
0
class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):
    """CRUD view for SQLA tables.

    Adding a table is two-phase: ``pre_add`` rejects duplicates and tables
    that cannot be reflected, ``post_add`` fetches column metadata and
    registers access permissions, and the user then edits the new table to
    configure it.
    """

    datamodel = SQLAInterface(models.SqlaTable)
    include_route_methods = RouteMethod.CRUD_SET

    list_title = _("Tables")
    show_title = _("Show Table")
    add_title = _("Import a table definition")
    edit_title = _("Edit Table")

    list_columns = ["link", "database_name", "changed_by_", "modified"]
    order_columns = ["modified"]
    add_columns = ["database", "schema", "table_name"]
    edit_columns = [
        "table_name",
        "sql",
        "filter_select_enabled",
        "fetch_values_predicate",
        "database",
        "schema",
        "description",
        "owners",
        "main_dttm_col",
        "default_endpoint",
        "offset",
        "cache_timeout",
        "is_sqllab_view",
        "template_params",
    ]
    base_filters = [["id", DatasourceFilter, lambda: []]]
    show_columns = edit_columns + ["perm", "slices"]
    related_views = [
        TableColumnInlineView,
        SqlMetricInlineView,
        RowLevelSecurityFiltersModelView,
    ]
    base_order = ("changed_on", "desc")
    search_columns = ("database", "schema", "table_name", "owners",
                      "is_sqllab_view")
    description_columns = {
        "slices":
        _("The list of charts associated with this table. By "
          "altering this datasource, you may change how these associated "
          "charts behave. "
          "Also note that charts need to point to a datasource, so "
          "this form will fail at saving if removing charts from a "
          "datasource. If you want to change the datasource for a chart, "
          "overwrite the chart from the 'explore view'"),
        "offset":
        _("Timezone offset (in hours) for this datasource"),
        "table_name":
        _("Name of the table that exists in the source database"),
        "schema":
        _("Schema, as used only in some databases like Postgres, Redshift "
          "and DB2"),
        "description":
        Markup(
            'Supports <a href="https://daringfireball.net/projects/markdown/">'
            "markdown</a>"),
        "sql":
        _("This fields acts a Superset view, meaning that Superset will "
          "run a query against this string as a subquery."),
        "fetch_values_predicate":
        _("Predicate applied when fetching distinct value to "
          "populate the filter control component. Supports "
          "jinja template syntax. Applies only when "
          "`Enable Filter Select` is on."),
        "default_endpoint":
        _("Redirects to this endpoint when clicking on the table "
          "from the table list"),
        "filter_select_enabled":
        _("Whether to populate the filter's dropdown in the explore "
          "view's filter section with a list of distinct values fetched "
          "from the backend on the fly"),
        "is_sqllab_view":
        _("Whether the table was generated by the 'Visualize' flow "
          "in SQL Lab"),
        "template_params":
        _("A set of parameters that become available in the query using "
          "Jinja templating syntax"),
        "cache_timeout":
        _("Duration (in seconds) of the caching timeout for this table. "
          "A timeout of 0 indicates that the cache never expires. "
          "Note this defaults to the database timeout if undefined."),
    }
    label_columns = {
        "slices": _("Associated Charts"),
        "link": _("Table"),
        "changed_by_": _("Changed By"),
        "database": _("Database"),
        "database_name": _("Database"),
        "changed_on_": _("Last Changed"),
        "filter_select_enabled": _("Enable Filter Select"),
        "schema": _("Schema"),
        "default_endpoint": _("Default Endpoint"),
        "offset": _("Offset"),
        "cache_timeout": _("Cache Timeout"),
        "table_name": _("Table Name"),
        "fetch_values_predicate": _("Fetch Values Predicate"),
        "owners": _("Owners"),
        "main_dttm_col": _("Main Datetime Column"),
        "description": _("Description"),
        "is_sqllab_view": _("SQL Lab View"),
        "template_params": _("Template parameters"),
        "modified": _("Modified"),
    }
    edit_form_extra_fields = {
        "database":
        QuerySelectField(
            "Database",
            query_factory=lambda: db.session().query(models.Database),
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    def pre_add(self, table):
        """Validate a table before it is inserted.

        :param table: the ``SqlaTable`` about to be added
        :raises Exception: if a table with the same name, schema and
            database already exists, or if the table cannot be reflected
            from the source database (the original error is chained)
        """
        # Disable autoflush so the pending ``table`` is not flushed by the
        # existence query below.
        with db.session.no_autoflush:
            table_query = db.session.query(models.SqlaTable).filter(
                models.SqlaTable.table_name == table.table_name,
                models.SqlaTable.schema == table.schema,
                models.SqlaTable.database_id == table.database.id,
            )
            if db.session.query(table_query.exists()).scalar():
                raise Exception(get_datasource_exist_error_msg(
                    table.full_name))

        # Fail before adding if the table can't be found
        try:
            table.get_sqla_table_object()
        except Exception as e:
            logger.exception(f"Got an error in pre_add for {table.name}")
            raise Exception(
                _("Table [{}] could not be found, "
                  "please double check your "
                  "database connection, schema, and "
                  "table name, error: {}").format(table.name, str(e))) from e

    def post_add(self, table, flash_message=True):
        """Fetch metadata and register permissions for a newly saved table.

        :param table: the saved ``SqlaTable``
        :param flash_message: whether to flash the "two-phase configuration"
            hint to the user (suppressed when called from ``post_update``)
        """
        table.fetch_metadata()
        security_manager.add_permission_view_menu("datasource_access",
                                                  table.get_perm())
        if table.schema:
            security_manager.add_permission_view_menu("schema_access",
                                                      table.schema_perm)

        if flash_message:
            flash(
                _("The table was created. "
                  "As part of this two-phase configuration "
                  "process, you should now click the edit button by "
                  "the new table to configure it."),
                "info",
            )

    def post_update(self, table):
        """Re-sync metadata and permissions after an edit, without flashing."""
        self.post_add(table, flash_message=False)

    def _delete(self, pk):
        DeleteMixin._delete(self, pk)

    @expose("/edit/<pk>", methods=["GET", "POST"])
    @has_access
    def edit(self, pk):
        """Simple hack to redirect to explore view after saving"""
        resp = super().edit(pk)
        # A str response means the edit form is rendered (GET or a failed
        # submit); any other response is replaced by the explore redirect.
        if isinstance(resp, str):
            return resp
        return redirect("/superset/explore/table/{}/".format(pk))

    @action("refresh", __("Refresh Metadata"), __("Refresh column metadata"),
            "fa-refresh")
    def refresh(self, tables):
        """Re-fetch column metadata for one or more tables.

        Each table is refreshed independently: a failure is logged and
        reported to the user without aborting the remaining tables.

        :param tables: a single table or a list of tables
        :return: a redirect to the table list
        """
        if not isinstance(tables, list):
            tables = [tables]
        successes = []
        failures = []
        for t in tables:
            try:
                t.fetch_metadata()
                successes.append(t)
            except Exception:
                # Best effort per table, but keep the traceback for debugging.
                logger.exception("Failed to refresh metadata for %s",
                                 t.table_name)
                failures.append(t)

        if successes:
            success_msg = _(
                "Metadata refreshed for the following table(s): %(tables)s",
                tables=", ".join([t.table_name for t in successes]),
            )
            flash(success_msg, "info")
        if failures:
            failure_msg = _(
                "Unable to retrieve metadata for the following table(s): %(tables)s",
                tables=", ".join([t.table_name for t in failures]),
            )
            flash(failure_msg, "danger")

        return redirect("/tablemodelview/list/")
Esempio n. 21
0
class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView):
    """Inline CRUD view for the columns of a Druid datasource.

    Saving a column validates its optional ``dimension_spec_json`` and then
    calls ``refresh_metrics()`` on the column.
    """

    datamodel = SQLAInterface(models.DruidColumn)
    include_route_methods = RouteMethod.RELATED_VIEW_SET

    list_title = _("Columns")
    show_title = _("Show Druid Column")
    add_title = _("Add Druid Column")
    edit_title = _("Edit Druid Column")

    list_widget = ListWidgetWithCheckboxes

    edit_columns = [
        "column_name",
        "verbose_name",
        "description",
        "dimension_spec_json",
        "datasource",
        "groupby",
        "filterable",
    ]
    add_columns = edit_columns
    list_columns = ["column_name", "verbose_name", "type", "groupby", "filterable"]
    can_delete = False
    page_size = 500
    label_columns = {
        "column_name": _("Column"),
        "type": _("Type"),
        "datasource": _("Datasource"),
        "groupby": _("Groupable"),
        "filterable": _("Filterable"),
    }
    description_columns = {
        "filterable": _(
            "Whether this column is exposed in the `Filters` section "
            "of the explore view."
        ),
        "dimension_spec_json": utils.markdown(
            "this field can be used to specify  "
            "a `dimensionSpec` as documented [here]"
            "(http://druid.io/docs/latest/querying/dimensionspecs.html). "
            "Make sure to input valid JSON and that the "
            "`outputName` matches the `column_name` defined "
            "above.",
            True,
        ),
    }

    add_form_extra_fields = {
        "datasource": QuerySelectField(
            "Datasource",
            query_factory=lambda: db.session.query(models.DruidDatasource),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    edit_form_extra_fields = add_form_extra_fields

    def pre_update(self, item: "models.DruidColumn") -> None:
        """Validate the column's dimension spec before saving.

        :param item: the ``DruidColumn`` being saved
        :raises ValueError: if ``dimension_spec_json`` is set but is not a
            JSON object with a ``dimension`` key and an ``outputName`` that
            equals ``item.column_name``
        """
        if not item.dimension_spec_json:
            return
        try:
            dimension_spec = json.loads(item.dimension_spec_json)
        except ValueError as ex:
            raise ValueError("Invalid Dimension Spec JSON: " + str(ex)) from ex
        if not isinstance(dimension_spec, dict):
            raise ValueError("Dimension Spec must be a JSON object")
        if "outputName" not in dimension_spec:
            raise ValueError("Dimension Spec does not contain `outputName`")
        if "dimension" not in dimension_spec:
            raise ValueError("Dimension Spec is missing `dimension`")
        # `outputName` should be the same as the `column_name`
        if dimension_spec["outputName"] != item.column_name:
            raise ValueError(
                "`outputName` [{}] unequal to `column_name` [{}]".format(
                    dimension_spec["outputName"], item.column_name
                )
            )

    def post_update(self, item: "models.DruidColumn") -> None:
        """Refresh the column's metrics after it has been saved."""
        item.refresh_metrics()

    def post_add(self, item: "models.DruidColumn") -> None:
        """Treat a newly added column like an updated one."""
        self.post_update(item)
Esempio n. 22
0
class BeaconModelView(ModelView):
    """CRUD view for Beacon definitions plus the LP -> Controller ingest API.

    A Beacon defines which incoming packets are processed
    (``beacon_filter``), how message data is extracted from them
    (``beacon_data_mapping``), and how response messages to the implant are
    formatted (``response_data_type`` / ``response_data_mapping``).

    NOTE: the form fields below are built once, at class-definition time, so
    the filter/field/response-type lists reflect what the ``get_*`` helpers
    returned at import.
    """

    datamodel = SQLAInterface(Beacon)
    base_permissions = [
        'can_list', 'can_show', 'can_add', 'can_edit', 'can_delete',
        'can_post_beacon'
    ]

    edit_form_extra_fields = {
        'beacon_filter':
        Field('Beacon Filter',
              widget=FilterBuilderWidget(
                  beacon_filters=get_querybuilder_filters_json(),
                  beacon_fields=get_all_beacon_fields_json()),
              validators=[validators.DataRequired()],
              description=('Only incoming Beacon packets matching this '
                           'filter will be processed')),
        'beacon_data_mapping':
        Field('Beacon Data Mapping',
              widget=BeaconFieldsWidget(packet_fields=get_all_packet_fields(),
                                        beacon_fields=get_all_beacon_fields()),
              validators=[validators.DataRequired(),
                          BeaconDataMappingCheck()],
              description=('Extract message data based on the selected '
                           'mapping schema')),
        'response_data_type':
        SelectField(
            'Response Type',
            choices=get_all_response_types(),
            description=(
                'The protocol used for response messages to the implant'),
            widget=Select2Widget()),
        'response_data_mapping':
        Field('Response Data Mapping',
              widget=ResponseFieldsWidget(
                  packet_fields=get_all_packet_fields(),
                  response_fields=get_all_task_fields()),
              validators=[validators.DataRequired(),
                          BeaconDataMappingCheck()],
              description=('Format response messages based on the selected '
                           'mapping schema'))
    }
    add_form_extra_fields = {
        'beacon_filter':
        Field('Beacon Filter',
              widget=FilterBuilderWidget(
                  beacon_filters=get_querybuilder_filters_json()),
              validators=[validators.DataRequired()],
              description=('Only incoming packets matching this filter will '
                           'be processed as a Beacon')),
        'beacon_data_mapping':
        Field(
            'Beacon Data Mapping',
            widget=BeaconFieldsWidget(packet_fields=get_all_packet_fields(),
                                      beacon_fields=get_all_beacon_fields()),
            validators=[validators.DataRequired(),
                        BeaconDataMappingCheck()],
            description=(
                'Extract message data based on the selected mapping schema')),
        'response_data_type':
        SelectField(
            'Response Type',
            choices=get_all_response_types(),
            description=(
                'The protocol used for response messages to the implant'),
            widget=Select2Widget()),
        'response_data_mapping':
        Field(
            'Response Data Mapping',
            widget=ResponseFieldsWidget(packet_fields=get_all_packet_fields(),
                                        response_fields=get_all_task_fields()),
            validators=[validators.DataRequired(),
                        BeaconDataMappingCheck()],
            description=(
                'Format reply messages based on the selected mapping schema'))
    }

    # The add, edit and list views all show the same columns. Each attribute
    # gets its own list so that a change to one does not leak into another.
    edit_columns = [
        'name', 'beacon_filter', 'beacon_data_mapping', 'response_data_type',
        'response_data_mapping'
    ]
    add_columns = list(edit_columns)
    list_columns = list(edit_columns)

    show_fieldsets = [
        ('Filter', {'fields': ['name', 'beacon_filter']}),
        ('Beacon Data Mapping', {'fields': ['beacon_data_mapping']}),
        ('Task Data Mapping', {
            'fields': ['response_data_type', 'response_data_mapping']
        }),
    ]

    description_columns = {'name': 'Simple name for easy reference'}

    @action("muldelete", "Delete", "Delete all Really?", "fa-trash")
    def muldelete(self, items):
        """Delete one or several selected beacons, then redirect back."""
        if isinstance(items, list):
            self.datamodel.delete_all(items)
            self.update_redirect()
        else:
            self.datamodel.delete(items)
        return redirect(self.get_redirect())

    @expose_api(name='post_beacon', url='/api/postbeacon', methods=['POST'])
    @has_access_api
    def post_beacon(self):
        """API used to send captured beacons from LP to Controller.

        Registers the implant on its first beacon. Records the receive time
        and external IP address for every beacon, whether the implant is new
        or already known. Stores any non-empty ``data`` payload in a
        ``DataStore`` row.

        :return: a ``'Success'`` response with HTTP status 200
        """
        beacon = json_to_beacon(request.data)
        now = datetime.now()

        # Look up the implant, registering it on first contact
        implant = db.session.query(Implant).filter_by(
            uuid=beacon['uuid']).first()
        if implant is None:
            implant = Implant(uuid=beacon['uuid'])
            db.session.add(implant)

        # Record last-seen details for new and existing implants alike
        implant.last_beacon_received = now
        implant.external_ip_address = beacon['external_ip_address']
        db.session.commit()

        # Store beacon data
        beacon_data = beacon['data'] if 'data' in beacon else None
        if beacon_data:
            datastore = DataStore(implant=[implant], timestamp=now)
            # ASCII payloads go to the text column, anything else to the
            # raw-data column.
            if is_ascii(beacon_data):
                datastore.text_received = beacon_data
            else:
                datastore.data_received = beacon_data
            db.session.add(datastore)
            db.session.commit()

        return make_response('Success', 200)