def validate_int_params(int_param, param_name):
    """
    Check that an integer parameter is positive and fits in a signed 64-bit range.

    Args:
        int_param (int): Value to check, e.g. dataset_batch_size or step_num.
        param_name (str): Name of the parameter being checked. Only
            'step_num' and 'dataset_batch_size' lead to an exception; an
            invalid value under any other name is accepted silently.

    Raises:
        MindInsightException: If the value is invalid and param_name is
            'step_num' or 'dataset_batch_size'.
    """
    upper_bound = 2 ** 63 - 1
    if isinstance(int_param, int) and 0 < int_param <= upper_bound:
        return

    if param_name == 'step_num':
        log.error(
            'Invalid step_num. The step number should be a positive integer.'
        )
        raise MindInsightException(
            error=LineageErrors.PARAM_STEP_NUM_ERROR,
            message=LineageErrorMsg.PARAM_STEP_NUM_ERROR.value)

    if param_name == 'dataset_batch_size':
        log.error('Invalid dataset_batch_size. '
                  'The batch size should be a positive integer.')
        raise MindInsightException(
            error=LineageErrors.PARAM_BATCH_SIZE_ERROR,
            message=LineageErrorMsg.PARAM_BATCH_SIZE_ERROR.value)
def check_comparision(self, data, **kwargs):
    """
    Check the compare condition of every search attribute in the schema data.

    Paging/sorting keys are skipped. Every other attribute must be a string
    that is either in FIELD_MAPPING or starts with 'metric/' or
    'user_defined/', and its condition must be a dict whose keys are
    supported compare operators. For 'metric/' attributes the condition
    values are additionally type-checked.

    Args:
        data (dict): Search condition, mapping attribute name to condition.
        kwargs (dict): Extra keyword arguments; not used here.

    Returns:
        dict, the unchanged input data.

    Raises:
        LineageParamValueError: If an attribute or compare operator is not supported.
        LineageParamTypeError: If a condition is not a dict.
        MindInsightException: If the value type of a metric condition is invalid.
    """
    skipped_keys = ("limit", "offset", "sorted_name", "sorted_type", 'lineage_type')
    compare_operators = ("eq", "lt", "gt", "le", "ge", "in")
    metric_prefix = 'metric/'

    for attr, condition in data.items():
        if attr in skipped_keys:
            continue

        supported = isinstance(attr, str) and (
            attr in FIELD_MAPPING
            or attr.startswith((metric_prefix, 'user_defined/'))
        )
        if not supported:
            raise LineageParamValueError('The search attribute not supported.')

        if not isinstance(condition, dict):
            raise LineageParamTypeError("The search_condition element {} should be dict."
                                        .format(attr))

        if any(operator not in compare_operators for operator in condition):
            raise LineageParamValueError("The compare condition should be in "
                                         "('eq', 'lt', 'gt', 'le', 'ge', 'in').")

        if not attr.startswith(metric_prefix):
            continue

        # A bare 'metric/' carries no metric name after the prefix.
        if len(attr) == len(metric_prefix):
            raise LineageParamValueError(
                'The search attribute not supported.'
            )
        try:
            SearchModelConditionParameter.check_param_value_type(condition)
        except ValidationError:
            raise MindInsightException(
                error=LineageErrors.LINEAGE_PARAM_METRIC_ERROR,
                message=LineageErrorMsg.LINEAGE_METRIC_ERROR.value.format(attr)
            )

    return data
def begin(self, run_context):
    """
    Prepare training lineage collection when the training job begins.

    Records the user defined info when there is any, resolves the initial
    learning rate if it has not been set yet, and records the serialized
    train dataset graph to the lineage summary.

    Args:
        run_context (RunContext): It contains all lineage information,
            see mindspore.train.callback.RunContext.

    Raises:
        LineageParamRunContextError: If run_context is not a RunContext.
        MindInsightException: If the optimizer in run_context is not an Optimizer.
        LineageLogError: If recording the dataset graph fails.
    """
    log.info('Initialize training lineage collection...')

    if self.user_defined_info:
        self.lineage_summary.record_user_defined_info(
            self.user_defined_info)

    if not isinstance(run_context, RunContext):
        error_msg = 'Invalid TrainLineage run_context.'
        log.error(error_msg)
        raise LineageParamRunContextError(error_msg)

    context_args = run_context.original_args()

    if not self.initial_learning_rate:
        optimizer = context_args.get('optimizer')
        if optimizer:
            if not isinstance(optimizer, Optimizer):
                log.error(
                    "The parameter optimizer is invalid. It should be an instance of "
                    "mindspore.nn.optim.optimizer.Optimizer.")
                raise MindInsightException(
                    error=LineageErrors.PARAM_OPTIMIZER_ERROR,
                    message=LineageErrorMsg.PARAM_OPTIMIZER_ERROR.value)
            log.info('Obtaining initial learning rate...')
        else:
            # No optimizer in the run context: look it up from the train network.
            optimizer = AnalyzeObject.get_optimizer_by_network(
                context_args.get('train_network'))
        self.initial_learning_rate = AnalyzeObject.analyze_optimizer(
            optimizer)
        log.debug('initial_learning_rate: %s', self.initial_learning_rate)

    # Serialize the train dataset and round-trip it through JSON before recording.
    serialized_graph = ds.serialize(context_args.get('train_dataset'))
    dataset_graph = json.loads(json.dumps(serialized_graph, indent=2))

    log.info('Logging dataset graph...')
    try:
        self.lineage_summary.record_dataset_graph(
            dataset_graph=dataset_graph)
    except Exception as error:
        error_msg = f'Dataset graph log error in TrainLineage begin: {error}'
        log.error(error_msg)
        raise LineageLogError(error_msg)
    log.info('Dataset graph logged successfully.')
def validate_raise_exception(raise_exception):
    """
    Check that raise_exception is a bool.

    Args:
        raise_exception (bool): Decides whether to raise exceptions. If True,
            raise exception; else, catch exception and continue.

    Raises:
        MindInsightException: If raise_exception is not a bool.
    """
    if isinstance(raise_exception, bool):
        return

    log.error("Invalid raise_exception. It should be True or False.")
    raise MindInsightException(
        error=LineageErrors.PARAM_RAISE_EXCEPTION_ERROR,
        message=LineageErrorMsg.PARAM_RAISE_EXCEPTION_ERROR.value)
def _get_lineage_info(lineage_type, search_condition):
    """
    Query lineage info for dataset or model and make summary dirs relative.

    Each returned lineage has its 'summary_dir' rewritten as a path relative
    to the summary base dir ('./' for the base dir itself).

    Args:
        lineage_type (str): Lineage type, 'dataset' or 'model'.
        search_condition (dict): Search condition. It must not contain
            'lineage_type'; for the 'dataset' type that key is added to it
            in place.

    Returns:
        dict, lineage info.

    Raises:
        ParamValueError: If 'lineage_type' is already in search_condition.
        MindInsightException: If the query fails; re-raised with http_code 400.
    """
    if 'lineage_type' in search_condition:
        raise ParamValueError(
            "Lineage type does not need to be assigned in a specific interface."
        )

    if lineage_type == 'dataset':
        search_condition.update(lineage_type='dataset')

    base_dir = str(settings.SUMMARY_BASE_DIR)
    try:
        lineage_info = filter_summary_lineage(base_dir, search_condition)
        lineages = lineage_info['object']

        real_base_dir = os.path.realpath(base_dir)
        # Skip the base dir and the path separator that follows it.
        tail_start = len(real_base_dir) + 1
        for lineage in lineages:
            real_summary_dir = os.path.realpath(lineage['summary_dir'])
            if real_summary_dir == real_base_dir:
                lineage['summary_dir'] = './'
            else:
                lineage['summary_dir'] = os.path.join(
                    os.curdir, real_summary_dir[tail_start:])
    except MindInsightException as exception:
        raise MindInsightException(exception.error, exception.message, http_code=400)

    return lineage_info
def test_raise_exception_record_trainlineage(self, *args):
    """
    Test exception when error happened after recording training infos.

    The train callback is expected to grow the lineage log between begin()
    and end(), while the eval callback's end() must leave the file size
    unchanged.
    """
    if os.path.exists(SUMMARY_DIR_3):
        shutil.rmtree(SUMMARY_DIR_3)

    # args[1] is presumably a mock injected by a patch decorator that is
    # not visible in this block; it is made to raise on call.
    args[1].side_effect = MindInsightException(error=LineageErrors.PARAM_RUN_CONTEXT_ERROR,
                                               message="RunContext error.")

    train_callback = TrainLineage(SUMMARY_DIR_3, True)
    train_callback.begin(RunContext(self.run_context))
    log_path = train_callback.lineage_summary.lineage_log_path
    size_after_begin = os.path.getsize(log_path)

    train_callback.end(RunContext(self.run_context))
    size_after_train_end = os.path.getsize(log_path)
    assert size_after_train_end > size_after_begin

    eval_callback = EvalLineage(SUMMARY_DIR_3, False)
    eval_callback.end(RunContext(self.run_context))
    size_after_eval_end = os.path.getsize(log_path)
    assert size_after_eval_end == size_after_train_end
def validate_search_model_condition(schema, data):
    """
    Validate the search model condition against a schema.

    The first validation error whose key is in SEARCH_MODEL_ERROR_MAPPING
    is logged and raised; errors with other keys are ignored.

    Args:
        schema (Schema): Data schema class; it is instantiated here.
        data (dict): Data to check schema.

    Raises:
        MindInsightException: If the parameters are invalid.
    """
    validation_errors = schema().validate(data)
    for key in validation_errors:
        if key not in SEARCH_MODEL_ERROR_MAPPING:
            continue
        error_code = SEARCH_MODEL_ERROR_MAPPING.get(key)
        error_msg = SEARCH_MODEL_ERROR_MSG_MAPPING.get(key)
        log.error(error_msg)
        raise MindInsightException(error=error_code, message=error_msg)
def validate_summary_record(summary_record):
    """
    Check that summary_record is a SummaryRecord instance.

    Args:
        summary_record (SummaryRecord): SummaryRecord is used to record
            the summary value, see mindspore.train.summary.SummaryRecord.

    Raises:
        MindInsightException: If summary_record is not a SummaryRecord.
    """
    if isinstance(summary_record, SummaryRecord):
        return

    log.error("Invalid summary_record. It should be an instance "
              "of mindspore.train.summary.SummaryRecord.")
    raise MindInsightException(
        error=LineageErrors.PARAM_SUMMARY_RECORD_ERROR,
        message=LineageErrorMsg.PARAM_SUMMARY_RECORD_ERROR.value)
def get_dataset_graph():
    """
    Get dataset graph.

    The train id of the request is used as the summary dir. It must be an
    absolute path, or a './' path relative to the summary base dir.

    Returns:
        str, the dataset graph information.

    Raises:
        MindInsightException: If method fails to be called.
        ParamValueError: If summary_dir is invalid.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/datasets/dataset_graph?train_id=xxx
    """
    base_dir = str(settings.SUMMARY_BASE_DIR)
    summary_dir = get_train_id(request)

    if summary_dir.startswith('/'):
        # NOTE(review): the return value of validate_path is discarded here
        # but kept in the relative branch - confirm the asymmetry is intended.
        validate_path(summary_dir)
    elif summary_dir.startswith('./'):
        summary_dir = validate_path(os.path.join(base_dir, summary_dir[2:]))
    else:
        raise ParamValueError("Summary dir should be absolute path or "
                              "relative path that relate to summary base dir.")

    try:
        dataset_graph = get_summary_lineage(summary_dir=summary_dir, keys=['dataset_graph'])
    except MindInsightException as exception:
        raise MindInsightException(exception.error, exception.message, http_code=400)

    if dataset_graph:
        if base_dir == dataset_graph.get('summary_dir'):
            relative_dir = './'
        else:
            # NOTE(review): the tail is sliced from the requested summary_dir,
            # not from the 'summary_dir' in the result - confirm.
            relative_dir = os.path.join(os.curdir, summary_dir[len(base_dir) + 1:])
        dataset_graph['summary_dir'] = relative_dir

    return jsonify(dataset_graph)
def validate_eval_run_context(schema, data):
    """
    Validate mindspore evaluation job run_context data according to schema.

    The first validation error whose key is in EVAL_RUN_CONTEXT_ERROR_MAPPING
    is logged and raised. The message from
    EVAL_RUN_CONTEXT_ERROR_MSG_MAPPING is used when it has a non-empty
    entry for the key; otherwise the schema's own message is used.

    Args:
        schema (Schema): Data schema class; it is instantiated here.
        data (dict): Data to check schema.

    Raises:
        MindInsightException: If the parameters are invalid.
    """
    validation_errors = schema().validate(data)
    for key, schema_msg in validation_errors.items():
        if key not in EVAL_RUN_CONTEXT_ERROR_MAPPING:
            continue
        error_code = EVAL_RUN_CONTEXT_ERROR_MAPPING.get(key)
        message = EVAL_RUN_CONTEXT_ERROR_MSG_MAPPING.get(key) or schema_msg
        log.error(message)
        raise MindInsightException(error=error_code, message=message)
def validate_file_path(file_path, allow_empty=False):
    """
    Check the file path and return its normalized form.

    Args:
        file_path (str): Input file path.
        allow_empty (bool): Whether file_path can be empty. When True, an
            empty file_path is returned unchanged without any check.

    Returns:
        The result of safe_normalize_path, or file_path itself when it is
        empty and allow_empty is True.

    Raises:
        MindInsightException: If the path fails validation.
    """
    if allow_empty and not file_path:
        return file_path

    try:
        return safe_normalize_path(file_path, raise_key='dataset_path', safe_prefixes=None)
    except ValidationError as error:
        error_msg = str(error)
        log.error(error_msg)
        raise MindInsightException(error=LineageErrors.PARAM_FILE_PATH_ERROR, message=error_msg)
def _get_lineage_info(search_condition):
    """
    Query lineage info for dataset or model through the data manager.

    Args:
        search_condition (dict): Search condition.

    Returns:
        dict, lineage info.

    Raises:
        MindInsightException: If the query fails; re-raised with http_code 400.
    """
    try:
        return filter_summary_lineage(
            data_manager=DATA_MANAGER,
            search_condition=search_condition,
        )
    except MindInsightException as exception:
        raise MindInsightException(exception.error, exception.message, http_code=400)
def _get_lineage_info(search_condition):
    """
    Query lineage info for dataset or model and make summary dirs relative.

    Each returned lineage has its 'summary_dir' rewritten as a path relative
    to the summary base dir ('./' for the base dir itself).

    Args:
        search_condition (dict): Search condition.

    Returns:
        dict, lineage info.

    Raises:
        MindInsightException: If the query fails; re-raised with http_code 400.
    """
    base_dir = str(settings.SUMMARY_BASE_DIR)
    try:
        lineage_info = general_filter_summary_lineage(
            data_manager=DATA_MANAGER,
            search_condition=search_condition,
            added=True)
        lineages = lineage_info['object']

        real_base_dir = os.path.realpath(base_dir)
        # Skip the base dir and the path separator that follows it.
        tail_start = len(real_base_dir) + 1
        for lineage in lineages:
            real_summary_dir = os.path.realpath(lineage['summary_dir'])
            if real_summary_dir == real_base_dir:
                lineage['summary_dir'] = './'
            else:
                lineage['summary_dir'] = os.path.join(
                    os.curdir, real_summary_dir[tail_start:])
    except MindInsightException as exception:
        raise MindInsightException(exception.error, exception.message, http_code=400)

    return lineage_info
def validate_network(network):
    """
    Verify if the network is valid.

    Args:
        network (Cell): See mindspore.nn.Cell.

    Raises:
        LineageParamMissingError: If the network is None (or otherwise falsy).
        MindInsightException: If the network is not an instance of Cell.
    """
    if not network:
        error_msg = "The input network for TrainLineage should not be None."
        log.error(error_msg)
        raise LineageParamMissingError(error_msg)

    if not isinstance(network, Cell):
        # The two literals are concatenated, so the first must end with a
        # space; without it the message read "instanceof mindspore.nn.Cell.".
        log.error("Invalid network. Network should be an instance "
                  "of mindspore.nn.Cell.")
        raise MindInsightException(
            error=LineageErrors.PARAM_TRAIN_NETWORK_ERROR,
            message=LineageErrorMsg.PARAM_TRAIN_NETWORK_ERROR.value)
def get_dataset_graph():
    """
    Get dataset graph.

    The train id of the request is used as the summary dir for the lineage
    query; the 'summary_dir' of a non-empty result is rewritten relative to
    the summary base dir.

    Returns:
        str, the dataset graph information.

    Raises:
        MindInsightException: If method fails to be called.
        ParamValueError: If summary_dir is invalid.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/datasets/dataset_graph?train_id=xxx
    """
    base_dir = str(settings.SUMMARY_BASE_DIR)
    summary_dir = get_train_id(request)

    try:
        dataset_graph = general_get_summary_lineage(
            DATA_MANAGER, summary_dir=summary_dir, keys=['dataset_graph'])
    except MindInsightException as exception:
        raise MindInsightException(exception.error, exception.message, http_code=400)

    if dataset_graph:
        if base_dir == dataset_graph.get('summary_dir'):
            relative_dir = './'
        else:
            # NOTE(review): the tail is sliced from the requested summary_dir,
            # not from the 'summary_dir' in the result - confirm.
            relative_dir = os.path.join(os.curdir, summary_dir[len(base_dir) + 1:])
        dataset_graph['summary_dir'] = relative_dir

    return jsonify(dataset_graph)
def get_dataset_graph():
    """
    Get dataset graph.

    Looks up the lineage whose summary dir matches the train id of the
    request and returns its dataset graph together with its summary dir.
    The result is empty when nothing matches or the first match has no
    dataset graph.

    Returns:
        str, the dataset graph information.

    Raises:
        MindInsightException: If method fails to be called.
        ParamValueError: If summary_dir is invalid.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/datasets/dataset_graph?train_id=xxx
    """
    train_id = get_train_id(request)
    validate_train_id(train_id)

    search_condition = {'summary_dir': {'in': [train_id]}}
    result = {}

    try:
        lineage_info = filter_summary_lineage(data_manager=DATA_MANAGER,
                                              search_condition=search_condition)
        objects = lineage_info.get('object')
    except MindInsightException as exception:
        raise MindInsightException(exception.error, exception.message, http_code=400)

    if objects:
        # Only the first matching lineage is reported.
        first_lineage = objects[0]
        dataset_graph = first_lineage.get('dataset_graph')
        if dataset_graph:
            result.update({'dataset_graph': dataset_graph})
            result.update({'summary_dir': first_lineage.get('summary_dir')})

    return jsonify(result)
def setup_logger(sub_module, log_name, **kwargs):
    """
    Setup logger with sub module name and log file name.

    A logger that already has handlers is returned as is, without being
    reconfigured.

    Args:
        sub_module (str): Sub module name, also for sub directory under logroot.
        log_name (str): Log name, also for log filename.
        console (bool): Whether to output log to stdout. Default: False.
        logfile (bool): Whether to output log to disk. Default: True.
        level (Enum): Log level. Default: INFO.
        formatter (str): Log format.
        propagate (bool): Whether to enable propagate feature. Default: False.
        maxBytes (int): Rotating max bytes. Default: 50M.
        backupCount (int): Rotating backup count. Default: 30.
        sub_log_name (str): When given and non-empty, used instead of
            log_name to look up the logger; the log filename still uses
            log_name.

    Returns:
        Logger, well-configured logger instance.

    Raises:
        MindInsightException: If maxBytes or backupCount is not a positive int.

    Examples:
        >>> from mindinsight.utils.log import setup_logger
        >>> logger = setup_logger('datavisual', 'flask.request', level=logging.DEBUG)

        >>> from mindinsight.utils.log import get_logger
        >>> logger = get_logger('datavisual', 'flask.request')

        >>> import logging
        >>> logger = logging.getLogger('datavisual.flask.request')
    """
    logger = get_logger(sub_module, kwargs.get('sub_log_name') or log_name)
    if logger.hasHandlers():
        return logger

    level = kwargs.get('level', settings.LOG_LEVEL)
    log_format = kwargs.get('formatter', None)
    logger.setLevel(level)
    logger.propagate = kwargs.get('propagate', False)
    log_format = log_format or settings.LOG_FORMAT

    if kwargs.get('console', False):
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.formatter = MindInsightFormatter(sub_module, log_format)
        logger.addHandler(stdout_handler)

    if not kwargs.get('logfile', True):
        return logger

    max_bytes = kwargs.get('maxBytes', settings.LOG_ROTATING_MAXBYTES)
    if not (isinstance(max_bytes, int) and max_bytes > 0):
        raise MindInsightException(GeneralErrors.PARAM_VALUE_ERROR,
                                   'maxBytes should be int type and > 0.')

    backup_count = kwargs.get('backupCount', settings.LOG_ROTATING_BACKUPCOUNT)
    if not (isinstance(backup_count, int) and backup_count > 0):
        raise MindInsightException(GeneralErrors.PARAM_VALUE_ERROR,
                                   'backupCount should be int type and > 0.')

    logfile_dir = os.path.join(settings.WORKSPACE, 'log', sub_module)
    # Read/write/execute bits shifted into the owner position of the mode.
    owner_mode = (os.R_OK | os.W_OK | os.X_OK) << 6
    os.makedirs(logfile_dir, mode=owner_mode, exist_ok=True)

    logfile_name = '{}.{}.log'.format(log_name, settings.PORT)
    file_handler = MultiCompatibleRotatingFileHandler(
        filename=os.path.join(logfile_dir, logfile_name),
        maxBytes=max_bytes,
        backupCount=backup_count,
        encoding='utf8')
    file_handler.formatter = MindInsightFormatter(sub_module, log_format)
    logger.addHandler(file_handler)

    return logger