Ejemplo n.º 1
0
    def list_tensors(self, train_id, tag):
        """
        List tensors of the given train job and tag.

        Args:
            train_id (str): ID for train job.
            tag (str): The tag name.

        Returns:
            list, the result of ``events_data.tensors(tag)``; items are in the
                NamedTuple format `collections.namedtuple('_Tensor', ['wall_time', 'event_step', 'value'])`,
                where the value contains the given tag data. An empty list is
                returned when the job's data loader has no events data to read
                (an AttributeError is raised while reading it).

        Raises:
            TrainJobNotExistError: If the train job is not in the loader pool.
            ParamValueError: If no data can be found in the train job for the tag.
        """
        loader_pool = self._get_snapshot_loader_pool()
        if not self._is_loader_in_loader_pool(train_id, loader_pool):
            raise TrainJobNotExistError(
                "Can not find the given train job in cache.")

        data_loader = loader_pool[train_id].data_loader

        tensors = []
        try:
            events_data = data_loader.get_events_data()
            tensors = events_data.tensors(tag)
        except KeyError as ex:
            error_msg = "Can not find any data in this train job by given tag."
            # Chain the original KeyError so the failing lookup stays visible.
            raise ParamValueError(error_msg) from ex
        except AttributeError:
            # Deliberate best-effort: a deleted or not-yet-loaded job has no
            # usable loader/events data, which is reported as "no tensors".
            logger.debug(
                "Train job %r has been deleted or it has not loaded data, "
                "and set tags to empty list.", train_id)

        return tensors
Ejemplo n.º 2
0
    def load(self):
        """
        Start loading data from the latest summary file to the loader.

        Sets the loader status to LOADING, collects the regular files of the
        summary directory that pass ``ExplainLoader._filter_files``, then keeps
        pulling events from the parser until it reports the end of the files
        or the status becomes STOP (presumably set from another thread to
        cancel loading -- confirm against callers).

        Raises:
            TrainJobNotExistError: If no usable summary file is found.
        """
        self.status = _LoaderStatus.LOADING.value
        summary_dir = self._loader_info['summary_dir']
        filenames = [
            filename for filename in FileHandler.list_dir(summary_dir)
            if FileHandler.is_file(FileHandler.join(summary_dir, filename))
        ]
        filenames = ExplainLoader._filter_files(filenames)

        if not filenames:
            raise TrainJobNotExistError(
                'No summary file found in %s, explain job will be delete.' %
                summary_dir)

        is_end = False
        while not is_end and self.status != _LoaderStatus.STOP.value:
            try:
                file_changed, is_end, event_dict = self._parser.list_events(
                    filenames)
            except UnknownError as ex:
                # Loading stops on parser failure as before, but the error is
                # logged instead of being swallowed silently.
                logger.warning(
                    'Stop loading summary files in %s because of error: %s',
                    summary_dir, str(ex))
                break

            if file_changed:
                # The summary files changed underneath us: drop what was
                # imported so far and reload from scratch.
                logger.info(
                    'Summary file in %s update, reload the data in the summary.',
                    summary_dir)
                self._clear_job()
            if event_dict:
                self._import_data_from_event(event_dict)
Ejemplo n.º 3
0
    def cache_train_job(self, train_id):
        """
        Cache the given train job.

        Reuses the loader already held in the loader pool when there is one.
        Otherwise every loader generator is asked for a loader of ``train_id``
        and the one with the latest update time is kept (on a tie, the later
        generator wins). A freshly created loader is switched from
        NOT_IN_CACHE to CACHING and added to the pool.

        Args:
            train_id (str): ID of the train job to cache.

        Returns:
            bool, True if a new loader was added to the pool, False if the
                job was already cached.

        Raises:
            TrainJobNotExistError: If no loader could be produced for the job.
        """
        with self._loader_pool_mutex:
            loader = None
            if self._is_loader_in_loader_pool(train_id, self._loader_pool):
                loader = self._loader_pool.get(train_id)

            need_reload = loader is None
            if need_reload:
                for generator in self._loader_generators:
                    candidate = generator.generate_loader_by_train_id(train_id)
                    if not loader or not loader.latest_update_time > candidate.latest_update_time:
                        loader = candidate

                if loader is None:
                    raise TrainJobNotExistError(train_id)

                # Flip NOT_IN_CACHE to CACHING before the next interval fires.
                if loader.cache_status == CacheStatus.NOT_IN_CACHE:
                    loader.cache_status = CacheStatus.CACHING

                self._add_loader(loader)

        self._update_loader_latest_update_time(loader.loader_id)
        return need_reload
Ejemplo n.º 4
0
def check_train_job_and_profiler_dir(profiler_dir_abs):
    """
    Check that a profiler directory and its parent train job directory exist.

    Args:
        profiler_dir_abs (str): Path of the profiler directory; its parent
            directory is treated as the train job directory.

    Raises:
        TrainJobNotExistError: If the parent (train job) directory is missing.
        ProfilerDirNotFoundException: If the profiler directory is missing.
    """
    train_job_dir = os.path.abspath(os.path.join(profiler_dir_abs, os.pardir))
    if not os.path.exists(train_job_dir):
        raise TrainJobNotExistError(error_detail=train_job_dir)
    if os.path.exists(profiler_dir_abs):
        return
    raise ProfilerDirNotFoundException(msg=profiler_dir_abs)
Ejemplo n.º 5
0
    def load(self):
        """
        Start loading data from the latest summary file to the loader.

        Collects the regular files of the summary directory, keeps those
        accepted by ``ExplainLoader._filter_files``, then repeatedly asks the
        parser for events until it reports the end of the files. When the
        parser signals a clean reload, the job data is cleared before the new
        events are imported.

        Raises:
            TrainJobNotExistError: If no usable summary file is found.
        """
        summary_dir = self._loader_info['summary_dir']
        candidates = [
            name for name in FileHandler.list_dir(summary_dir)
            if FileHandler.is_file(FileHandler.join(summary_dir, name))
        ]
        filenames = ExplainLoader._filter_files(candidates)

        if not filenames:
            raise TrainJobNotExistError(
                'No summary file found in %s, explain job will be delete.' %
                summary_dir)

        # The parser is polled at least once; the loop ends after the batch
        # in which it reports the end of the files.
        while True:
            is_clean, is_end, event_dict = self._parser.parse_explain(filenames)
            if is_clean:
                logger.info(
                    'Summary file in %s update, reload the data in the summary.',
                    summary_dir)
                self._clear_job()
            if event_dict:
                self._import_data_from_event(event_dict)
            if is_end:
                break
Ejemplo n.º 6
0
 def query_meta(self, train_id):
     """
     Query explain job meta-data.

     Args:
         train_id (str): Job ID.

     Returns:
         dict, the metadata produced by ``self._job_2_meta`` for the job.

     Raises:
         TrainJobNotExistError: If the job manager has no job for ``train_id``.
     """
     job = self.job_manager.get_job(train_id)
     if job is not None:
         return self._job_2_meta(job)
     raise TrainJobNotExistError(train_id)
    def query_hierarchical_occlusion(
        self,
        train_id,
        labels,
        limit,
        offset,
        sorted_name,
        sorted_type,
        prediction_types=None,
        drop_empty=True,
    ):
        """
        Query hierarchical occlusion results.

        Args:
            train_id (str): Job ID.
            labels (list[str]): Label filter.
            limit (int): Maximum number of items to be returned (page size).
            offset (int): Page offset (page index; the page starts at item
                ``offset * limit``).
            sorted_name (str): Field to be sorted.
            sorted_type (str): Sorting order, 'ascending' or 'descending'.
            prediction_types (list[str]): Prediction types filter.
            drop_empty (bool): Whether to drop out the data without hoc data. Default: True.

        Returns:
            tuple[int, list[dict]], total number of samples after filtering and list of sample results.
                The list is empty when the page is out of range, ``offset`` is
                negative, or ``limit`` is not positive.

        Raises:
            TrainJobNotExistError: If the job manager has no job for ``train_id``.
        """
        job = self.job_manager.get_job(train_id)
        if job is None:
            raise TrainJobNotExistError(train_id)

        if drop_empty:
            # Ask the query to drop samples that carry no hoc data.
            samples = self._query_samples(job,
                                          labels,
                                          sorted_name,
                                          sorted_type,
                                          prediction_types,
                                          drop_type=ExplanationKeys.HOC.value)
        else:
            samples = self._query_samples(job, labels, sorted_name,
                                          sorted_type, prediction_types)

        count = len(samples)
        # Slicing clamps an out-of-range page to []. A negative offset used to
        # wrap around (or raise IndexError) with index-based access, so it now
        # yields an empty page explicitly.
        if offset < 0 or limit <= 0:
            page = []
        else:
            start = offset * limit
            page = samples[start:start + limit]
        sample_infos = [self._touch_sample(sample, job, drop_empty) for sample in page]

        return count, sample_infos
Ejemplo n.º 8
0
def validate_and_normalize_profiler_path(summary_dir, summary_base_dir):
    """
    Validate and normalize profiler path.

    The profiler directory chosen is the lexicographically greatest entry of
    the train job directory whose name matches ``^profiler.*``.

    Args:
        summary_dir (str): The relative path of summary directory; it is
            URL-unquoted (strict mode) before use.
        summary_base_dir (str): The summary base directory.

    Returns:
        str, normalized path of profiler directory.

    Raises:
        ProfilerParamValueErrorException: If `summary_dir` is empty, cannot be
            unquoted, or the train job / profiler path fails validation.
        TrainJobNotExistError: If the train job directory does not exist.
        ProfilerDirNotFoundException: If the train job directory contains no
            profiler directory.
    """
    profiler_directory_pattern = r'^profiler.*'
    if not summary_dir:
        raise ProfilerParamValueErrorException('The file dir does not exist.')
    try:
        unquote_path = unquote(summary_dir, errors='strict')
    except UnicodeDecodeError as err:
        raise ProfilerParamValueErrorException(
            'Unquote error with strict mode') from err
    train_job_dir = os.path.join(summary_base_dir, unquote_path)
    try:
        train_job_dir_abs = validate_and_normalize_path(
            train_job_dir, 'train_job_dir')
    except ValidationError as err:
        log.error('train_job dir <%s> is invalid', train_job_dir)
        raise ProfilerParamValueErrorException('train_job dir is invalid.') from err
    if not os.path.exists(train_job_dir_abs):
        raise TrainJobNotExistError(error_detail=train_job_dir_abs)

    profiler_name_list = []
    for dir_name in os.listdir(train_job_dir_abs):
        search_res = re.search(profiler_directory_pattern, dir_name)
        if search_res:
            profiler_name_list.append(search_res[0])
    # Explicit emptiness check: indexing an empty list would otherwise raise a
    # bare IndexError instead of the documented "not found" error.
    if not profiler_name_list:
        log.error('no valid profiler dir under <%s>', train_job_dir_abs)
        raise ProfilerDirNotFoundException('Profiler dir not found.')
    profiler_name_newest = max(profiler_name_list)
    profiler_dir = os.path.join(summary_base_dir, unquote_path,
                                profiler_name_newest)
    try:
        profiler_dir = validate_and_normalize_path(profiler_dir, 'profiler')
    except ValidationError as err:
        log.error('profiler dir <%s> is invalid', profiler_dir)
        raise ProfilerParamValueErrorException('Profiler dir is invalid.') from err

    return profiler_dir
Ejemplo n.º 9
0
    def generate_loader_by_train_id(self, train_id):
        """
        Generate loader by train_id.

        Args:
            train_id (str): Train ID of a summary directory, e.g. './log1'.

        Returns:
            The single loader generated for the summary directory of
                ``train_id`` (not a dict; presumably a LoaderStruct, as stored
                in the loader pool -- confirm).

        Raises:
            TrainJobNotExistError: If the summary directory of ``train_id``
                does not exist.
        """
        relative_path = self._get_relative_path_from_train_id(train_id)
        try:
            loader = self._generate_loader_by_relative_path(relative_path)
        except PathNotExistError as ex:
            # Translate to the train-job error while keeping the cause chained.
            raise TrainJobNotExistError(str(ex)) from ex

        return loader
Ejemplo n.º 10
0
    def query_saliency_maps(self,
                            train_id,
                            labels,
                            explainers,
                            limit,
                            offset,
                            sorted_name,
                            sorted_type,
                            prediction_types=None):
        """
        Query saliency maps.

        Args:
            train_id (str): Job ID.
            labels (list[str]): Label filter.
            explainers (list[str]): Explainers of saliency maps to be shown.
            limit (int): Maximum number of items to be returned (page size).
            offset (int): Page offset (page index; the page starts at item
                ``offset * limit``).
            sorted_name (str): Field to be sorted.
            sorted_type (str): Sorting order, 'ascending' or 'descending'.
            prediction_types (list[str]): Prediction types filter. Default: None.

        Returns:
            tuple[int, list[dict]], total number of samples after filtering and list of sample result.
                The list is empty when the page is out of range, ``offset`` is
                negative, or ``limit`` is not positive.

        Raises:
            TrainJobNotExistError: If the job manager has no job for ``train_id``.
        """
        job = self.job_manager.get_job(train_id)
        if job is None:
            raise TrainJobNotExistError(train_id)

        samples = self._query_samples(job, labels, sorted_name, sorted_type,
                                      prediction_types)

        count = len(samples)
        # Slicing clamps an out-of-range page to []. A negative offset used to
        # wrap around (or raise IndexError) with index-based access, so it now
        # yields an empty page explicitly.
        if offset < 0 or limit <= 0:
            page = []
        else:
            start = offset * limit
            page = samples[start:start + limit]
        sample_infos = [self._touch_sample(sample, job, explainers) for sample in page]

        return count, sample_infos
Ejemplo n.º 11
0
    def _check_train_job_exist(self, train_id, loader_pool):
        """
        Check train job exist, if not exist, will raise exception.

        Args:
            train_id (str): The given train job id.
            loader_pool (dict[str, LoaderStruct]): Refer to self._loader_pool.

        Raises:
            TrainJobNotExistError: Can not find train job in data manager.
        """
        is_exist = False
        if train_id in loader_pool:
            return
        for generator in self._loader_generators:
            if generator.check_train_job_exist(train_id):
                is_exist = True
                break
        if not is_exist:
            raise TrainJobNotExistError("Can not find the train job in data manager.")
Ejemplo n.º 12
0
    def cache_train_job(self, train_id):
        """
        Cache the given train job.

        Reuses the loader already held in the loader pool when there is one.
        Otherwise every loader generator is asked for a loader of ``train_id``,
        the one with the latest update time is kept (on a tie, the later
        generator wins), and it is added to the pool.

        Args:
            train_id (str): ID of the train job to cache.

        Returns:
            bool, True if a new loader was added to the pool, False if the
                job was already cached.

        Raises:
            TrainJobNotExistError: If no loader could be produced for the job.
        """
        with self._loader_pool_mutex:
            loader = None
            if self._is_loader_in_loader_pool(train_id, self._loader_pool):
                loader = self._loader_pool.get(train_id)

            need_reload = loader is None
            if need_reload:
                for generator in self._loader_generators:
                    candidate = generator.generate_loader_by_train_id(train_id)
                    if not loader or not loader.latest_update_time > candidate.latest_update_time:
                        loader = candidate

                if loader is None:
                    raise TrainJobNotExistError(train_id)

                self._add_loader(loader)

        self._update_loader_latest_update_time(loader.loader_id)
        return need_reload
Ejemplo n.º 13
0
 def get_train_job(self, train_id):
     """
     Get a cached train job.

     Args:
         train_id (str): ID of the train job to look up.

     Returns:
         The item stored under ``train_id`` in ``self._cache_items``.

     Raises:
         TrainJobNotExistError: If ``train_id`` is not in the cache.
     """
     cache_items = self._cache_items
     try:
         cached_job = cache_items[train_id]
     except KeyError:
         raise TrainJobNotExistError(train_id)
     return cached_job
Ejemplo n.º 14
0
    def query_saliency_maps(self, train_id, labels, explainers, limit, offset,
                            sorted_name, sorted_type):
        """
        Query saliency maps.

        Args:
            train_id (str): Job ID.
            labels (list[str]): Label filter; a sample is kept when any of its
                inferred labels is in this list. Empty/None disables filtering.
            explainers (list[str]): Explainers of saliency maps to be shown.
            limit (int): Max. no. of items to be returned (page size).
            offset (int): Page offset (page index; the page starts at item
                ``offset * limit``).
            sorted_name (str): Field to be sorted: '', 'confidence' or
                'uncertainty'.
            sorted_type (str): Sorting order, 'ascending' or 'descending'.

        Returns:
            tuple[int, list[dict]], total no. of samples after filtering and
                list of sample result. The list is empty when the page is out
                of range, ``offset`` is negative, or ``limit`` is not positive.

        Raises:
            TrainJobNotExistError: If the job manager has no job for ``train_id``.
            ParamValueError: If ``sorted_name`` is 'uncertainty' while the job
                has uncertainty disabled, or is any other unsupported value.
        """
        job = self.job_manager.get_job(train_id)
        if job is None:
            raise TrainJobNotExistError(train_id)

        # Work on a deep copy, presumably so that filtering, sorting and
        # _touch_sample cannot mutate the job's own sample records.
        samples = copy.deepcopy(job.get_all_samples())
        if labels:
            samples = [
                sample for sample in samples
                if any(inference["label"] in labels
                       for inference in sample["inferences"])
            ]

        reverse = sorted_type == "descending"
        if sorted_name == "confidence":
            if reverse:
                samples.sort(key=_sort_key_max_confidence, reverse=reverse)
            else:
                samples.sort(key=_sort_key_min_confidence, reverse=reverse)
        elif sorted_name == "uncertainty":
            if not job.uncertainty_enabled:
                raise ParamValueError(
                    "Uncertainty is not enabled, sorted_name cannot be 'uncertainty'"
                )
            if reverse:
                samples.sort(key=_sort_key_max_confidence_sd, reverse=reverse)
            else:
                samples.sort(key=_sort_key_min_confidence_sd, reverse=reverse)
        elif sorted_name != "":
            raise ParamValueError("sorted_name")

        count = len(samples)
        # Slicing clamps an out-of-range page to []. A negative offset used to
        # wrap around (or raise IndexError) with index-based access, so it now
        # yields an empty page explicitly.
        if offset < 0 or limit <= 0:
            page = []
        else:
            start = offset * limit
            page = samples[start:start + limit]
        sample_infos = [self._touch_sample(sample, job, explainers) for sample in page]

        return count, sample_infos
Ejemplo n.º 15
0
 def query_explainer_scores(self, train_id):
     """
     Query evaluation scores.

     Args:
         train_id (str): Job ID.

     Returns:
         A deep copy of the job's ``explainer_scores``, so mutating the result
         does not affect the job.

     Raises:
         TrainJobNotExistError: If the job manager has no job for ``train_id``.
     """
     job = self.job_manager.get_job(train_id)
     if job is not None:
         return copy.deepcopy(job.explainer_scores)
     raise TrainJobNotExistError(train_id)